#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import json
from copy import deepcopy

import pandas as pd

from rag.utils.doc_store_conn import OrderByExpr, FusionExpr
from rag.nlp.search import Dealer


class KGSearch(Dealer):
    def search(self, req, idxnm: str | list[str], kb_ids: list[str], emb_mdl=None, highlight=False):
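        """Retrieve knowledge-graph chunks for req["question"]: entities,
        community reports and original text are fetched separately and each
        group is collapsed into a single carrier chunk by merge_into_first().

        Minimal call sketch; `store`, `idx_name` and `emb` stand in for
        whatever doc-store connection, index name and embedding wrapper the
        caller has configured:

            res = KGSearch(store).search({"question": "Who is Tom?"},
                                         idx_name, ["kb_1"], emb_mdl=emb)
        """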
        def merge_into_first(sres, title="") -> dict[str, str]:
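            """Collapse every retrieved chunk into the first one: payloads
            that parse as JSON are rendered together as one CSV table,
            everything else is joined as plain text, and the merged string
            replaces the first hit's content_with_weight.
            """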
            if not sres:
                return {}
            content_with_weight = ""
            df, texts = [], []
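            # Payloads that parse as JSON become table rows; anything else is
            # kept as plain text.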
            for d in sres.values():
                try:
                    df.append(json.loads(d["content_with_weight"]))
                except Exception:
                    texts.append(d["content_with_weight"])
            if df:
                content_with_weight = title + "\n" + pd.DataFrame(df).to_csv()
            else:
                content_with_weight = title + "\n" + "\n".join(texts)
            # Reuse the first hit as the carrier for the merged content.
            first_id = next(iter(sres))
            first_source = deepcopy(sres[first_id])
            first_source["content_with_weight"] = content_with_weight
            return {first_id: first_source}

        qst = req.get("question", "")
        matchText, keywords = self.qryr.question(qst, min_match=0.05)
        condition = self.get_filters(req)

        ## Entity retrieval
        condition.update({"knowledge_graph_kwd": ["entity"]})
        assert emb_mdl, "No embedding model selected"
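        # Assuming Dealer.get_vector's (txt, emb_mdl, topk, similarity) order:
        # 1024 is the dense top-k and 0.1 the default similarity floor.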
        matchDense = self.get_vector(qst, emb_mdl, 1024, req.get("similarity", 0.1))
        q_vec = matchDense.embedding_data
        src = req.get("fields", ["docnm_kwd", "content_ltks", "kb_id", "img_id", "title_tks", "important_kwd",
                                 "doc_id", f"q_{len(q_vec)}_vec", "position_list", "name_kwd",
                                 "available_int", "content_with_weight",
                                 "weight_int", "weight_flt"
                                 ])

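        # weighted_sum fuses the full-text and dense scores (0.5 each here)
        # and keeps the top 32 fused hits.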
        fusionExpr = FusionExpr("weighted_sum", 32, {"weights": "0.5, 0.5"})

        ent_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 32, idxnm, kb_ids)
        ent_res_fields = self.dataStore.getFields(ent_res, src)
        entities = [d["name_kwd"] for d in ent_res_fields.values() if d.get("name_kwd")]
        ent_ids = self.dataStore.getChunkIds(ent_res)
        ent_content = merge_into_first(ent_res_fields, "-Entities-")
        if ent_content:
            ent_ids = list(ent_content.keys())

        ## Community retrieval: only reports that mention the entities found above
        condition = self.get_filters(req)
        condition.update({"entities_kwd": entities, "knowledge_graph_kwd": ["community_report"]})
        comm_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 32, idxnm, kb_ids)
        comm_res_fields = self.dataStore.getFields(comm_res, src)
        comm_ids = self.dataStore.getChunkIds(comm_res)
        comm_content = merge_into_first(comm_res_fields, "-Community Report-")
        if comm_content:
            comm_ids = list(comm_content.keys())

        ## Text content retrieval
        condition = self.get_filters(req)
        condition.update({"knowledge_graph_kwd": ["text"]})
        txt_res = self.dataStore.search(src, list(), condition, [matchText, matchDense, fusionExpr], OrderByExpr(), 0, 6, idxnm, kb_ids)
        txt_res_fields = self.dataStore.getFields(txt_res, src)
        txt_ids = self.dataStore.getChunkIds(txt_res)
        txt_content = merge_into_first(txt_res_fields, "-Original Content-")
        if txt_content:
            txt_ids = list(txt_content.keys())

        return self.SearchResult(
            total=len(ent_ids) + len(comm_ids) + len(txt_ids),
            ids=[*ent_ids, *comm_ids, *txt_ids],
            query_vector=q_vec,
            highlight=None,
            field={**ent_content, **comm_content, **txt_content},
            keywords=[]
        )
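

# Usage sketch (hypothetical names: `store` is a configured doc-store
# connection, `idx_name` an index name, `emb` an embedding-model wrapper):
#
#   kg = KGSearch(store)
#   res = kg.search({"question": "What is GraphRAG?"}, idx_name, ["kb_1"], emb_mdl=emb)
#   print(res.total, res.ids[:3])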