mirror of
				https://github.com/infiniflow/ragflow.git
				synced 2025-10-31 17:59:43 +00:00 
			
		
		
		
	 90ffcb4ddb
			
		
	
	
		90ffcb4ddb
		
			
		
	
	
	
	
		
			
			### What problem does this PR solve? Fix graphrag + infinity bugs ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
		
			
				
	
	
		
			104 lines
		
	
	
		
			4.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			104 lines
		
	
	
		
			4.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #
 | |
| #  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
 | |
| #
 | |
| #  Licensed under the Apache License, Version 2.0 (the "License");
 | |
| #  you may not use this file except in compliance with the License.
 | |
| #  You may obtain a copy of the License at
 | |
| #
 | |
| #      http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| #  Unless required by applicable law or agreed to in writing, software
 | |
| #  distributed under the License is distributed on an "AS IS" BASIS,
 | |
| #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| #  See the License for the specific language governing permissions and
 | |
| #  limitations under the License.
 | |
| #
 | |
| import json
 | |
| from copy import deepcopy
 | |
| 
 | |
| import pandas as pd
 | |
| from rag.utils.doc_store_conn import OrderByExpr, FusionExpr
 | |
| 
 | |
| from rag.nlp.search import Dealer
 | |
| 
 | |
| 
 | |
class KGSearch(Dealer):
    """Knowledge-graph-aware retrieval: merges entity, community-report and
    original-text hits from the doc store into a single SearchResult."""

    def search(self, req, idxnm: str | list[str], kb_ids: list[str], emb_mdl=None, highlight=False):
        """Run a three-stage KG retrieval for ``req["question"]``.

        Stages (each a fused full-text + dense-vector search):
          1. entities        (knowledge_graph_kwd == "entity")
          2. community reports, filtered by the entity names found in stage 1
          3. original text chunks

        Each stage's hits are collapsed into a single synthetic chunk via
        ``merge_into_first`` so the caller receives at most one record per stage.

        :param req: request dict; reads "question", "fields", "similarity" plus
                    whatever ``self.get_filters`` consumes.
        :param idxnm: index name(s) to search.
        :param kb_ids: knowledge-base ids used as a search scope.
        :param emb_mdl: embedding model; required (asserted) for dense matching.
        :param highlight: accepted for interface compatibility; not used here.
        :return: ``self.SearchResult`` with merged ids/fields of all three stages.
        """

        def merge_into_first(sres, title="") -> dict[str, str]:
            """Collapse all hits in ``sres`` into one record keyed by the first
            hit's id; its content becomes ``title`` + a CSV table (when every
            content parses as JSON) or the raw texts joined by newlines."""
            if not sres:
                return {}
            rows, texts = [], []
            for d in sres.values():
                try:
                    # Entity/community chunks store JSON payloads; fall back to
                    # plain text for anything that does not parse.
                    rows.append(json.loads(d["content_with_weight"]))
                except Exception:
                    texts.append(d["content_with_weight"])
            if rows:
                content_with_weight = title + "\n" + pd.DataFrame(rows).to_csv()
            else:
                content_with_weight = title + "\n" + "\n".join(texts)
            # BUG FIX: the original assigned the *builtin* ``id`` function to
            # ``first_id`` inside a dead one-iteration loop (a typo for ``k``),
            # then overwrote it afterwards. Take the first key directly instead.
            first_id = next(iter(sres))
            first_source = deepcopy(sres[first_id])
            first_source["content_with_weight"] = content_with_weight
            return {first_id: first_source}

        qst = req.get("question", "")
        matchText, keywords = self.qryr.question(qst, min_match=0.05)
        condition = self.get_filters(req)

        ## Entity retrieval
        condition.update({"knowledge_graph_kwd": ["entity"]})
        assert emb_mdl, "No embedding model selected"
        matchDense = self.get_vector(qst, emb_mdl, 1024, req.get("similarity", 0.1))
        q_vec = matchDense.embedding_data
        src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
                                 "doc_id", f"q_{len(q_vec)}_vec", "position_list", "name_kwd",
                                 "available_int", "content_with_weight",
                                 "weight_int", "weight_flt"
                                 ])

        # 50/50 fusion of full-text and dense scores over the top 32 candidates.
        fusionExpr = FusionExpr("weighted_sum", 32, {"weights": "0.5, 0.5"})

        ent_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 32, idxnm, kb_ids)
        ent_res_fields = self.dataStore.getFields(ent_res, src)
        entities = [d["name_kwd"] for d in ent_res_fields.values() if d.get("name_kwd")]
        ent_ids = self.dataStore.getChunkIds(ent_res)
        ent_content = merge_into_first(ent_res_fields, "-Entities-")
        if ent_content:
            # After merging, only the single synthetic chunk's id remains.
            ent_ids = list(ent_content.keys())

        ## Community retrieval — scoped to the entities found above.
        condition = self.get_filters(req)
        condition.update({"entities_kwd": entities, "knowledge_graph_kwd": ["community_report"]})
        comm_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 32, idxnm, kb_ids)
        comm_res_fields = self.dataStore.getFields(comm_res, src)
        comm_ids = self.dataStore.getChunkIds(comm_res)
        comm_content = merge_into_first(comm_res_fields, "-Community Report-")
        if comm_content:
            comm_ids = list(comm_content.keys())

        ## Text content retrieval (smaller page size: 6).
        condition = self.get_filters(req)
        condition.update({"knowledge_graph_kwd": ["text"]})
        txt_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 6, idxnm, kb_ids)
        txt_res_fields = self.dataStore.getFields(txt_res, src)
        txt_ids = self.dataStore.getChunkIds(txt_res)
        txt_content = merge_into_first(txt_res_fields, "-Original Content-")
        if txt_content:
            txt_ids = list(txt_content.keys())

        return self.SearchResult(
            total=len(ent_ids) + len(comm_ids) + len(txt_ids),
            ids=[*ent_ids, *comm_ids, *txt_ids],
            query_vector=q_vec,
            highlight=None,
            field={**ent_content, **comm_content, **txt_content},
            keywords=[]
        )
 |