Mirror of https://github.com/infiniflow/ragflow.git, synced 2025-10-30 17:29:40 +00:00

Commit 152072f900
### What problem does this PR solve?

#1594

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

110 lines · 4.3 KiB · Python
#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import json
from copy import deepcopy

import pandas as pd
from elasticsearch_dsl import Q, Search

from rag.nlp.search import Dealer

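
# KGSearch layers a knowledge-graph-aware pass on top of the generic Dealer
# retriever: the same index is queried three times (graph entities, community
# reports, original text chunks) and each pass's hits are collapsed into one
# synthetic hit by merge_into_first below.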
class KGSearch(Dealer):
    def search(self, req, idxnm, emb_mdl=None):
        def merge_into_first(sres, title=""):
            # Collapse every hit into the first one, prefixed with `title`:
            # JSON contents are merged into a CSV table via pandas, plain
            # text is concatenated. Returns True if anything was merged.
            df, texts = [], []
            for d in sres["hits"]["hits"]:
                try:
                    df.append(json.loads(d["_source"]["content_with_weight"]))
                except Exception:
                    texts.append(d["_source"]["content_with_weight"])
            if not df and not texts:
                return False
            if df:
                try:
                    sres["hits"]["hits"][0]["_source"]["content_with_weight"] = title + "\n" + pd.DataFrame(df).to_csv()
                except Exception:
                    pass
            else:
                sres["hits"]["hits"][0]["_source"]["content_with_weight"] = title + "\n" + "\n".join(texts)
            return True

        src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
                                 "image_id", "doc_id", "q_512_vec", "q_768_vec", "position_int", "name_kwd",
                                 "q_1024_vec", "q_1536_vec", "available_int", "content_with_weight",
                                 "weight_int", "weight_flt", "rank_int"
                                 ])

        qst = req.get("question", "")
        binary_query, keywords = self.qryr.question(qst, min_match="5%")
        binary_query = self._add_filters(binary_query, req)

        ## Entity retrieval
        bqry = deepcopy(binary_query)
        bqry.filter.append(Q("terms", knowledge_graph_kwd=["entity"]))
        s = Search()
        s = s.query(bqry)[0:32]

        s = s.to_dict()
        q_vec = []
        if req.get("vector"):
            assert emb_mdl, "No embedding model selected"
            # Attach a kNN clause over the question embedding, constrained
            # by the same boolean filters as the keyword query.
            s["knn"] = self._vector(
                qst, emb_mdl, req.get(
                    "similarity", 0.1), 1024)
            s["knn"]["filter"] = bqry.to_dict()
            q_vec = s["knn"]["query_vector"]

        ent_res = self.es.search(deepcopy(s), idxnm=idxnm, timeout="600s", src=src)
        entities = [d["name_kwd"] for d in self.es.getSource(ent_res)]
        ent_ids = self.es.getDocIds(ent_res)
        if merge_into_first(ent_res, "-Entities-"):
            ent_ids = ent_ids[0:1]

        ## Community retrieval
        # Only keep community reports that mention the entities found above.
        bqry = deepcopy(binary_query)
        bqry.filter.append(Q("terms", entities_kwd=entities))
        bqry.filter.append(Q("terms", knowledge_graph_kwd=["community_report"]))
        s = Search()
        s = s.query(bqry)[0:32]
        s = s.to_dict()
        comm_res = self.es.search(deepcopy(s), idxnm=idxnm, timeout="600s", src=src)
        comm_ids = self.es.getDocIds(comm_res)
        if merge_into_first(comm_res, "-Community Report-"):
            comm_ids = comm_ids[0:1]

        ## Text content retrieval
        bqry = deepcopy(binary_query)
        bqry.filter.append(Q("terms", knowledge_graph_kwd=["text"]))
        s = Search()
        s = s.query(bqry)[0:6]
        s = s.to_dict()
        txt_res = self.es.search(deepcopy(s), idxnm=idxnm, timeout="600s", src=src)
        txt_ids = self.es.getDocIds(txt_res)
        if merge_into_first(txt_res, "-Original Content-"):
            txt_ids = txt_ids[0:1]

        return self.SearchResult(
            total=len(ent_ids) + len(comm_ids) + len(txt_ids),
            ids=[*ent_ids, *comm_ids, *txt_ids],
            query_vector=q_vec,
            aggregation=None,
            highlight=None,
            field={**self.getFields(ent_res, src), **self.getFields(comm_res, src), **self.getFields(txt_res, src)},
            keywords=[]
        )
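
For orientation, a minimal sketch of how this retriever might be driven follows. It is an assumption-laden illustration, not code from this PR: the module path, the Elasticsearch connection passed to the constructor, the index name, and the embedding-model handle are all hypothetical; only the "question", "vector", and "similarity" request keys are ones that search() above actually reads.

```python
# Minimal usage sketch; illustrative assumptions only. The module path,
# es_conn wiring, index name, and emb_mdl handle are hypothetical.
from graphrag.search import KGSearch  # assumed location of this file

kg = KGSearch(es_conn)  # constructed however Dealer is wired upstream

req = {
    "question": "What does InfiniFlow build?",
    "vector": True,       # also attach the kNN clause (needs emb_mdl)
    "similarity": 0.1,    # similarity cutoff forwarded to _vector()
}
res = kg.search(req, idxnm="ragflow_tenant_index", emb_mdl=emb_mdl)

# res.ids holds at most one merged hit per pass:
# entities, community report, original content.
print(res.total, res.ids)
```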