#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import os
import re
from collections import Counter
from copy import deepcopy

import numpy as np
from huggingface_hub import snapshot_download

from api.utils.file_utils import get_project_base_directory
from deepdoc.vision import Recognizer


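# LayoutRecognizer wraps the DeepDoc layout-detection model (loaded from
# rag/res/deepdoc, or downloaded from the InfiniFlow/deepdoc HuggingFace repo
# when the local copy is missing) and maps OCR text boxes onto detected layout
# regions such as titles, tables, figures, headers and footers.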
class LayoutRecognizer(Recognizer):
    labels = [
        "_background_",
        "Text",
        "Title",
        "Figure",
        "Figure caption",
        "Table",
        "Table caption",
        "Header",
        "Footer",
        "Reference",
        "Equation",
    ]

    def __init__(self, domain):
        try:
            model_dir = os.path.join(
                    get_project_base_directory(),
                    "rag/res/deepdoc")
            super().__init__(self.labels, domain, model_dir)
        except Exception as e:
            model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc",
                                          local_dir=os.path.join(get_project_base_directory(), "rag/res/deepdoc"),
                                          local_dir_use_symlinks=False)
            super().__init__(self.labels, domain, model_dir)

        self.garbage_layouts = ["footer", "header", "reference"]

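    # Run layout detection on the page images, scale the detected regions back
    # into OCR coordinates, and tag each OCR box with the region it overlaps
    # ("layout_type" / "layoutno"). With drop=True, header/footer/reference
    # text is removed, and garbage lines that repeat across pages are filtered
    # out at the end. Returns (filtered OCR boxes, per-page layout regions).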
    def __call__(self, image_list, ocr_res, scale_factor=3,
                 thr=0.2, batch_size=16, drop=True):
        def __is_garbage(b):
            # Common page furniture: bullet runs, Chinese copyright/disclaimer/
            # address lines, dot leaders, "1 / 12"-style page numbers, long URLs,
            # "data source" notes, e-mail addresses and PDF "(cid:NN)" artifacts.
            patt = [r"^•+$", r"(版权归©|免责条款|地址[::])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$",
                    r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}",
                    "(资料|数据)来源[::]", "[0-9a-z._-]+@[a-z0-9-]+\\.[a-z]{2,3}",
                    "\\(cid *: *[0-9]+ *\\)"
                    ]
            return any([re.search(p, b["text"]) for p in patt])

        layouts = super().__call__(image_list, thr, batch_size)
        # save_results(image_list, layouts, self.labels, output_dir='output/', threshold=0.7)
        assert len(image_list) == len(ocr_res)
        # Tag layout type
        boxes = []
        assert len(image_list) == len(layouts)
        garbages = {}
        page_layout = []
        for pn, lts in enumerate(layouts):
            bxs = ocr_res[pn]
            lts = [{"type": b["type"],
                    "score": float(b["score"]),
                    "x0": b["bbox"][0] / scale_factor, "x1": b["bbox"][2] / scale_factor,
                    "top": b["bbox"][1] / scale_factor, "bottom": b["bbox"][-1] / scale_factor,
                    "page_number": pn,
                    } for b in lts]
            lts = self.sort_Y_firstly(lts, np.mean(
                [l["bottom"] - l["top"] for l in lts]) / 2)
            lts = self.layouts_cleanup(bxs, lts)
            page_layout.append(lts)

            # Tag layout type, layouts are ready
            def findLayout(ty):
                nonlocal bxs, lts, self
                lts_ = [lt for lt in lts if lt["type"] == ty]
                i = 0
                while i < len(bxs):
                    if bxs[i].get("layout_type"):
                        i += 1
                        continue
                    if __is_garbage(bxs[i]):
                        bxs.pop(i)
                        continue

                    ii = self.find_overlapped_with_threashold(bxs[i], lts_,
                                                              thr=0.4)
                    if ii is None:  # belongs to nothing
                        bxs[i]["layout_type"] = ""
                        i += 1
                        continue
                    lts_[ii]["visited"] = True
                    keep_feats = [
                        lts_[
                            ii]["type"] == "footer" and bxs[i]["bottom"] < image_list[pn].size[1] * 0.9 / scale_factor,
                        lts_[
                            ii]["type"] == "header" and bxs[i]["top"] > image_list[pn].size[1] * 0.1 / scale_factor,
                    ]
                    if drop and lts_[
                            ii]["type"] in self.garbage_layouts and not any(keep_feats):
                        if lts_[ii]["type"] not in garbages:
                            garbages[lts_[ii]["type"]] = []
                        garbages[lts_[ii]["type"]].append(bxs[i]["text"])
                        bxs.pop(i)
                        continue

                    bxs[i]["layoutno"] = f"{ty}-{ii}"
                    bxs[i]["layout_type"] = lts_[ii]["type"] if lts_[
                        ii]["type"] != "equation" else "figure"
                    i += 1

            for lt in ["footer", "header", "reference", "figure caption",
                       "table caption", "title", "table", "text", "figure", "equation"]:
                findLayout(lt)

            # Add a placeholder box for figure/equation layouts that matched no text box
            for i, lt in enumerate(
                    [lt for lt in lts if lt["type"] in ["figure", "equation"]]):
                if lt.get("visited"):
                    continue
                lt = deepcopy(lt)
                del lt["type"]
                lt["text"] = ""
                lt["layout_type"] = "figure"
                lt["layoutno"] = f"figure-{i}"
                bxs.append(lt)

            boxes.extend(bxs)

        ocr_res = boxes

        # Text that was dropped as header/footer/reference more than once
        # (i.e. repeats across pages) is also stripped from the remaining boxes.
        garbag_set = set()
        for k in garbages.keys():
            garbages[k] = Counter(garbages[k])
            for g, c in garbages[k].items():
                if c > 1:
                    garbag_set.add(g)

        ocr_res = [b for b in ocr_res if b["text"].strip() not in garbag_set]
        return ocr_res, page_layout
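

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the upstream file). It assumes
# the "layout" model domain and OCR boxes given as dicts with "text", "x0",
# "x1", "top" and "bottom" keys in the same coordinate space as the page image
# divided by `scale_factor`; adapt these assumptions to your own pipeline.
if __name__ == "__main__":
    from PIL import Image

    recognizer = LayoutRecognizer("layout")

    pages = [Image.open("page-000.png")]   # one PIL image per page
    ocr_boxes = [[                         # one list of OCR text boxes per page
        {"text": "Introduction", "x0": 40.0, "x1": 180.0, "top": 60.0, "bottom": 80.0},
        {"text": "1 / 12", "x0": 280.0, "x1": 320.0, "top": 820.0, "bottom": 835.0},
    ]]

    boxes, page_layouts = recognizer(pages, ocr_boxes, scale_factor=1, drop=True)
    for b in boxes:
        # Surviving boxes carry "layout_type" and, when matched to a region,
        # a "layoutno" such as "text-0" or "table-1".
        print(b.get("layout_type"), b.get("layoutno"), b["text"])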
