| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  | # | 
					
						
							|  |  |  | #  Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | #  Licensed under the Apache License, Version 2.0 (the "License"); | 
					
						
							|  |  |  | #  you may not use this file except in compliance with the License. | 
					
						
							|  |  |  | #  You may obtain a copy of the License at | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | #      http://www.apache.org/licenses/LICENSE-2.0 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | #  Unless required by applicable law or agreed to in writing, software | 
					
						
							|  |  |  | #  distributed under the License is distributed on an "AS IS" BASIS, | 
					
						
							|  |  |  | #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
					
						
							|  |  |  | #  See the License for the specific language governing permissions and | 
					
						
							|  |  |  | #  limitations under the License. | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | import datetime | 
					
						
							|  |  |  | import json | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from flask import request | 
					
						
							|  |  |  | from flask_login import login_required, current_user | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from rag.app.qa import rmPrefix, beAdoc | 
					
						
							| 
									
										
										
										
											2025-02-26 15:40:52 +08:00
										 |  |  | from rag.app.tag import label_question | 
					
						
							| 
									
										
										
										
											2024-10-22 13:12:49 +08:00
										 |  |  | from rag.nlp import search, rag_tokenizer | 
					
						
							| 
									
										
										
										
											2025-02-26 15:40:52 +08:00
										 |  |  | from rag.prompts import keyword_extraction | 
					
						
							| 
									
										
										
										
											2025-01-09 17:07:21 +08:00
										 |  |  | from rag.settings import PAGERANK_FLD | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  | from rag.utils import rmSpace | 
					
						
							|  |  |  | from api.db import LLMType, ParserType | 
					
						
							|  |  |  | from api.db.services.knowledgebase_service import KnowledgebaseService | 
					
						
							| 
									
										
										
										
											2024-09-18 16:09:22 +08:00
										 |  |  | from api.db.services.llm_service import LLMBundle | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  | from api.db.services.user_service import UserTenantService | 
					
						
							|  |  |  | from api.utils.api_utils import server_error_response, get_data_error_result, validate_request | 
					
						
							|  |  |  | from api.db.services.document_service import DocumentService | 
					
						
							| 
									
										
										
										
											2024-11-15 17:30:56 +08:00
										 |  |  | from api import settings | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  | from api.utils.api_utils import get_json_result | 
					
						
							| 
									
										
										
										
											2024-12-12 17:47:39 +08:00
										 |  |  | import xxhash | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  | import re | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-15 17:30:56 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-08 21:23:51 +08:00
										 |  |  | @manager.route('/list', methods=['POST'])  # noqa: F821 | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  | @login_required | 
					
						
							|  |  |  | @validate_request("doc_id") | 
					
						
							|  |  |  | def list_chunk(): | 
					
						
							|  |  |  |     req = request.json | 
					
						
							|  |  |  |     doc_id = req["doc_id"] | 
					
						
							|  |  |  |     page = int(req.get("page", 1)) | 
					
						
							|  |  |  |     size = int(req.get("size", 30)) | 
					
						
							|  |  |  |     question = req.get("keywords", "") | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         tenant_id = DocumentService.get_tenant_id(req["doc_id"]) | 
					
						
							|  |  |  |         if not tenant_id: | 
					
						
							| 
									
										
										
										
											2024-11-05 11:02:31 +08:00
										 |  |  |             return get_data_error_result(message="Tenant not found!") | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  |         e, doc = DocumentService.get_by_id(doc_id) | 
					
						
							|  |  |  |         if not e: | 
					
						
							| 
									
										
										
										
											2024-11-05 11:02:31 +08:00
										 |  |  |             return get_data_error_result(message="Document not found!") | 
					
						
							| 
									
										
										
										
											2024-11-12 14:59:41 +08:00
										 |  |  |         kb_ids = KnowledgebaseService.get_kb_ids(tenant_id) | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  |         query = { | 
					
						
							|  |  |  |             "doc_ids": [doc_id], "page": page, "size": size, "question": question, "sort": True | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         if "available_int" in req: | 
					
						
							|  |  |  |             query["available_int"] = int(req["available_int"]) | 
					
						
							| 
									
										
										
										
											2024-11-15 17:30:56 +08:00
										 |  |  |         sres = settings.retrievaler.search(query, search.index_name(tenant_id), kb_ids, highlight=True) | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  |         res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()} | 
					
						
							|  |  |  |         for id in sres.ids: | 
					
						
							|  |  |  |             d = { | 
					
						
							|  |  |  |                 "chunk_id": id, | 
					
						
							|  |  |  |                 "content_with_weight": rmSpace(sres.highlight[id]) if question and id in sres.highlight else sres.field[ | 
					
						
							|  |  |  |                     id].get( | 
					
						
							|  |  |  |                     "content_with_weight", ""), | 
					
						
							|  |  |  |                 "doc_id": sres.field[id]["doc_id"], | 
					
						
							|  |  |  |                 "docnm_kwd": sres.field[id]["docnm_kwd"], | 
					
						
							|  |  |  |                 "important_kwd": sres.field[id].get("important_kwd", []), | 
					
						
							| 
									
										
										
										
											2024-12-05 14:51:19 +08:00
										 |  |  |                 "question_kwd": sres.field[id].get("question_kwd", []), | 
					
						
							| 
									
										
										
										
											2024-11-12 14:59:41 +08:00
										 |  |  |                 "image_id": sres.field[id].get("img_id", ""), | 
					
						
							| 
									
										
										
										
											2024-12-05 16:49:43 +08:00
										 |  |  |                 "available_int": int(sres.field[id].get("available_int", 1)), | 
					
						
							| 
									
										
										
										
											2024-12-10 16:32:58 +08:00
										 |  |  |                 "positions": sres.field[id].get("position_int", []), | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  |             } | 
					
						
							| 
									
										
										
										
											2024-11-12 14:59:41 +08:00
										 |  |  |             assert isinstance(d["positions"], list) | 
					
						
							| 
									
										
										
										
											2024-11-15 17:30:56 +08:00
										 |  |  |             assert len(d["positions"]) == 0 or (isinstance(d["positions"][0], list) and len(d["positions"][0]) == 5) | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  |             res["chunks"].append(d) | 
					
						
							|  |  |  |         return get_json_result(data=res) | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         if str(e).find("not_found") > 0: | 
					
						
							| 
									
										
										
										
											2024-11-05 11:02:31 +08:00
										 |  |  |             return get_json_result(data=False, message='No chunk found!', | 
					
						
							| 
									
										
										
										
											2024-11-15 17:30:56 +08:00
										 |  |  |                                    code=settings.RetCode.DATA_ERROR) | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  |         return server_error_response(e) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-08 21:23:51 +08:00
										 |  |  | @manager.route('/get', methods=['GET'])  # noqa: F821 | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  | @login_required | 
					
						
							|  |  |  | def get(): | 
					
						
							|  |  |  |     chunk_id = request.args["chunk_id"] | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         tenants = UserTenantService.query(user_id=current_user.id) | 
					
						
							|  |  |  |         if not tenants: | 
					
						
							| 
									
										
										
										
											2024-11-05 11:02:31 +08:00
										 |  |  |             return get_data_error_result(message="Tenant not found!") | 
					
						
							| 
									
										
										
										
											2025-02-26 10:24:35 +08:00
										 |  |  |         for tenant in tenants: | 
					
						
							|  |  |  |             kb_ids = KnowledgebaseService.get_kb_ids(tenant.tenant_id) | 
					
						
							|  |  |  |             chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant.tenant_id), kb_ids) | 
					
						
							|  |  |  |             if chunk: | 
					
						
							|  |  |  |                 break | 
					
						
							| 
									
										
										
										
											2024-11-12 14:59:41 +08:00
										 |  |  |         if chunk is None: | 
					
						
							| 
									
										
										
										
											2024-11-28 18:56:10 +08:00
										 |  |  |             return server_error_response(Exception("Chunk not found")) | 
					
						
							| 
									
										
										
										
											2025-02-26 10:24:35 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  |         k = [] | 
					
						
							| 
									
										
										
										
											2024-11-12 14:59:41 +08:00
										 |  |  |         for n in chunk.keys(): | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  |             if re.search(r"(_vec$|_sm_|_tks|_ltks)", n): | 
					
						
							|  |  |  |                 k.append(n) | 
					
						
							|  |  |  |         for n in k: | 
					
						
							| 
									
										
										
										
											2024-11-12 14:59:41 +08:00
										 |  |  |             del chunk[n] | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-12 14:59:41 +08:00
										 |  |  |         return get_json_result(data=chunk) | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  |     except Exception as e: | 
					
						
							|  |  |  |         if str(e).find("NotFoundError") >= 0: | 
					
						
							| 
									
										
										
										
											2024-11-05 11:02:31 +08:00
										 |  |  |             return get_json_result(data=False, message='Chunk not found!', | 
					
						
							| 
									
										
										
										
											2024-11-15 17:30:56 +08:00
										 |  |  |                                    code=settings.RetCode.DATA_ERROR) | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  |         return server_error_response(e) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-08 21:23:51 +08:00
										 |  |  | @manager.route('/set', methods=['POST'])  # noqa: F821 | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  | @login_required | 
					
						
							| 
									
										
										
										
											2025-01-10 19:06:59 +08:00
										 |  |  | @validate_request("doc_id", "chunk_id", "content_with_weight") | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  | def set(): | 
					
						
							|  |  |  |     req = request.json | 
					
						
							|  |  |  |     d = { | 
					
						
							|  |  |  |         "id": req["chunk_id"], | 
					
						
							|  |  |  |         "content_with_weight": req["content_with_weight"]} | 
					
						
							|  |  |  |     d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"]) | 
					
						
							|  |  |  |     d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) | 
					
						
							| 
									
										
										
										
											2025-01-10 19:06:59 +08:00
										 |  |  |     if "important_kwd" in req: | 
					
						
							| 
									
										
										
										
											2025-01-09 17:07:21 +08:00
										 |  |  |         d["important_kwd"] = req["important_kwd"] | 
					
						
							|  |  |  |         d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_kwd"])) | 
					
						
							| 
									
										
										
										
											2025-01-10 19:06:59 +08:00
										 |  |  |     if "question_kwd" in req: | 
					
						
							| 
									
										
										
										
											2025-01-09 17:07:21 +08:00
										 |  |  |         d["question_kwd"] = req["question_kwd"] | 
					
						
							|  |  |  |         d["question_tks"] = rag_tokenizer.tokenize("\n".join(req["question_kwd"])) | 
					
						
							| 
									
										
										
										
											2025-01-10 19:06:59 +08:00
										 |  |  |     if "tag_kwd" in req: | 
					
						
							| 
									
										
										
										
											2025-01-09 17:07:21 +08:00
										 |  |  |         d["tag_kwd"] = req["tag_kwd"] | 
					
						
							| 
									
										
										
										
											2025-01-10 19:06:59 +08:00
										 |  |  |     if "tag_feas" in req: | 
					
						
							|  |  |  |         d["tag_feas"] = req["tag_feas"] | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  |     if "available_int" in req: | 
					
						
							|  |  |  |         d["available_int"] = req["available_int"] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         tenant_id = DocumentService.get_tenant_id(req["doc_id"]) | 
					
						
							|  |  |  |         if not tenant_id: | 
					
						
							| 
									
										
										
										
											2024-11-05 11:02:31 +08:00
										 |  |  |             return get_data_error_result(message="Tenant not found!") | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         embd_id = DocumentService.get_embd_id(req["doc_id"]) | 
					
						
							| 
									
										
										
										
											2024-09-18 16:09:22 +08:00
										 |  |  |         embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_id) | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         e, doc = DocumentService.get_by_id(req["doc_id"]) | 
					
						
							|  |  |  |         if not e: | 
					
						
							| 
									
										
										
										
											2024-11-05 11:02:31 +08:00
										 |  |  |             return get_data_error_result(message="Document not found!") | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         if doc.parser_id == ParserType.QA: | 
					
						
							|  |  |  |             arr = [ | 
					
						
							|  |  |  |                 t for t in re.split( | 
					
						
							|  |  |  |                     r"[\n\t]", | 
					
						
							|  |  |  |                     req["content_with_weight"]) if len(t) > 1] | 
					
						
							| 
									
										
										
										
											2024-12-25 19:11:16 +08:00
										 |  |  |             q, a = rmPrefix(arr[0]), rmPrefix("\n".join(arr[1:])) | 
					
						
							| 
									
										
										
										
											2025-01-22 19:43:14 +08:00
										 |  |  |             d = beAdoc(d, q, a, not any( | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  |                 [rag_tokenizer.is_chinese(t) for t in q + a])) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-10 19:06:59 +08:00
										 |  |  |         v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d.get("question_kwd") else "\n".join(d["question_kwd"])]) | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  |         v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1] | 
					
						
							|  |  |  |         d["q_%d_vec" % len(v)] = v.tolist() | 
					
						
							| 
									
										
										
										
											2024-11-28 13:00:38 +08:00
										 |  |  |         settings.docStoreConn.update({"id": req["chunk_id"]}, d, search.index_name(tenant_id), doc.kb_id) | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  |         return get_json_result(data=True) | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         return server_error_response(e) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-08 21:23:51 +08:00
										 |  |  | @manager.route('/switch', methods=['POST'])  # noqa: F821 | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  | @login_required | 
					
						
							|  |  |  | @validate_request("chunk_ids", "available_int", "doc_id") | 
					
						
							|  |  |  | def switch(): | 
					
						
							|  |  |  |     req = request.json | 
					
						
							|  |  |  |     try: | 
					
						
							| 
									
										
										
										
											2024-11-12 14:59:41 +08:00
										 |  |  |         e, doc = DocumentService.get_by_id(req["doc_id"]) | 
					
						
							|  |  |  |         if not e: | 
					
						
							|  |  |  |             return get_data_error_result(message="Document not found!") | 
					
						
							| 
									
										
										
										
											2024-11-22 12:25:42 +08:00
										 |  |  |         for cid in req["chunk_ids"]: | 
					
						
							|  |  |  |             if not settings.docStoreConn.update({"id": cid}, | 
					
						
							|  |  |  |                                                 {"available_int": int(req["available_int"])}, | 
					
						
							|  |  |  |                                                 search.index_name(DocumentService.get_tenant_id(req["doc_id"])), | 
					
						
							|  |  |  |                                                 doc.kb_id): | 
					
						
							|  |  |  |                 return get_data_error_result(message="Index updating failure") | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  |         return get_json_result(data=True) | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         return server_error_response(e) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-08 21:23:51 +08:00
										 |  |  | @manager.route('/rm', methods=['POST'])  # noqa: F821 | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  | @login_required | 
					
						
							|  |  |  | @validate_request("chunk_ids", "doc_id") | 
					
						
							|  |  |  | def rm(): | 
					
						
							|  |  |  |     req = request.json | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         e, doc = DocumentService.get_by_id(req["doc_id"]) | 
					
						
							|  |  |  |         if not e: | 
					
						
							| 
									
										
										
										
											2024-11-05 11:02:31 +08:00
										 |  |  |             return get_data_error_result(message="Document not found!") | 
					
						
							| 
									
										
										
										
											2024-11-15 17:30:56 +08:00
										 |  |  |         if not settings.docStoreConn.delete({"id": req["chunk_ids"]}, search.index_name(current_user.id), doc.kb_id): | 
					
						
							| 
									
										
										
										
											2024-11-12 14:59:41 +08:00
										 |  |  |             return get_data_error_result(message="Index updating failure") | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  |         deleted_chunk_ids = req["chunk_ids"] | 
					
						
							|  |  |  |         chunk_number = len(deleted_chunk_ids) | 
					
						
							|  |  |  |         DocumentService.decrement_chunk_num(doc.id, doc.kb_id, 1, chunk_number, 0) | 
					
						
							|  |  |  |         return get_json_result(data=True) | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         return server_error_response(e) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-08 21:23:51 +08:00
										 |  |  | @manager.route('/create', methods=['POST'])  # noqa: F821 | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  | @login_required | 
					
						
							|  |  |  | @validate_request("doc_id", "content_with_weight") | 
					
						
							|  |  |  | def create(): | 
					
						
							|  |  |  |     req = request.json | 
					
						
							| 
									
										
										
										
											2024-12-12 17:47:39 +08:00
										 |  |  |     chunck_id = xxhash.xxh64((req["content_with_weight"] + req["doc_id"]).encode("utf-8")).hexdigest() | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  |     d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]), | 
					
						
							|  |  |  |          "content_with_weight": req["content_with_weight"]} | 
					
						
							|  |  |  |     d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) | 
					
						
							|  |  |  |     d["important_kwd"] = req.get("important_kwd", []) | 
					
						
							|  |  |  |     d["important_tks"] = rag_tokenizer.tokenize(" ".join(req.get("important_kwd", []))) | 
					
						
							| 
									
										
										
										
											2024-12-05 14:51:19 +08:00
										 |  |  |     d["question_kwd"] = req.get("question_kwd", []) | 
					
						
							|  |  |  |     d["question_tks"] = rag_tokenizer.tokenize("\n".join(req.get("question_kwd", []))) | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  |     d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19] | 
					
						
							|  |  |  |     d["create_timestamp_flt"] = datetime.datetime.now().timestamp() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         e, doc = DocumentService.get_by_id(req["doc_id"]) | 
					
						
							|  |  |  |         if not e: | 
					
						
							| 
									
										
										
										
											2024-11-05 11:02:31 +08:00
										 |  |  |             return get_data_error_result(message="Document not found!") | 
					
						
							| 
									
										
										
										
											2025-01-09 17:07:21 +08:00
										 |  |  |         d["kb_id"] = [doc.kb_id] | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  |         d["docnm_kwd"] = doc.name | 
					
						
							| 
									
										
										
										
											2024-12-03 14:30:35 +08:00
										 |  |  |         d["title_tks"] = rag_tokenizer.tokenize(doc.name) | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  |         d["doc_id"] = doc.id | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         tenant_id = DocumentService.get_tenant_id(req["doc_id"]) | 
					
						
							|  |  |  |         if not tenant_id: | 
					
						
							| 
									
										
										
										
											2024-11-05 11:02:31 +08:00
										 |  |  |             return get_data_error_result(message="Tenant not found!") | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-03 14:30:35 +08:00
										 |  |  |         e, kb = KnowledgebaseService.get_by_id(doc.kb_id) | 
					
						
							|  |  |  |         if not e: | 
					
						
							|  |  |  |             return get_data_error_result(message="Knowledgebase not found!") | 
					
						
							| 
									
										
										
										
											2024-12-08 14:21:12 +08:00
										 |  |  |         if kb.pagerank: | 
					
						
							| 
									
										
										
										
											2025-01-09 17:07:21 +08:00
										 |  |  |             d[PAGERANK_FLD] = kb.pagerank | 
					
						
							| 
									
										
										
										
											2024-12-03 14:30:35 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  |         embd_id = DocumentService.get_embd_id(req["doc_id"]) | 
					
						
							| 
									
										
										
										
											2024-09-18 16:09:22 +08:00
										 |  |  |         embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING.value, embd_id) | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-05 14:51:19 +08:00
										 |  |  |         v, c = embd_mdl.encode([doc.name, req["content_with_weight"] if not d["question_kwd"] else "\n".join(d["question_kwd"])]) | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  |         v = 0.1 * v[0] + 0.9 * v[1] | 
					
						
							|  |  |  |         d["q_%d_vec" % len(v)] = v.tolist() | 
					
						
							| 
									
										
										
										
											2024-11-15 17:30:56 +08:00
										 |  |  |         settings.docStoreConn.insert([d], search.index_name(tenant_id), doc.kb_id) | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         DocumentService.increment_chunk_num( | 
					
						
							|  |  |  |             doc.id, doc.kb_id, c, 1, 0) | 
					
						
							|  |  |  |         return get_json_result(data={"chunk_id": chunck_id}) | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         return server_error_response(e) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-08 21:23:51 +08:00
										 |  |  | @manager.route('/retrieval_test', methods=['POST'])  # noqa: F821 | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  | @login_required | 
					
						
							|  |  |  | @validate_request("kb_id", "question") | 
					
						
							|  |  |  | def retrieval_test(): | 
					
						
							|  |  |  |     req = request.json | 
					
						
							|  |  |  |     page = int(req.get("page", 1)) | 
					
						
							|  |  |  |     size = int(req.get("size", 30)) | 
					
						
							|  |  |  |     question = req["question"] | 
					
						
							| 
									
										
										
										
											2024-11-12 14:59:41 +08:00
										 |  |  |     kb_ids = req["kb_id"] | 
					
						
							|  |  |  |     if isinstance(kb_ids, str): | 
					
						
							|  |  |  |         kb_ids = [kb_ids] | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  |     doc_ids = req.get("doc_ids", []) | 
					
						
							| 
									
										
										
										
											2024-09-12 17:51:20 +08:00
										 |  |  |     similarity_threshold = float(req.get("similarity_threshold", 0.0)) | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  |     vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3)) | 
					
						
							| 
									
										
										
										
											2025-01-22 19:43:14 +08:00
										 |  |  |     use_kg = req.get("use_kg", False) | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  |     top = int(req.get("top_k", 1024)) | 
					
						
							| 
									
										
										
										
											2024-11-20 19:45:50 +08:00
										 |  |  |     tenant_ids = [] | 
					
						
							| 
									
										
										
										
											2024-09-11 19:49:18 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  |     try: | 
					
						
							| 
									
										
										
										
											2024-09-11 19:49:18 +08:00
										 |  |  |         tenants = UserTenantService.query(user_id=current_user.id) | 
					
						
							| 
									
										
										
										
											2024-11-12 14:59:41 +08:00
										 |  |  |         for kb_id in kb_ids: | 
					
						
							| 
									
										
										
										
											2024-09-11 19:49:18 +08:00
										 |  |  |             for tenant in tenants: | 
					
						
							|  |  |  |                 if KnowledgebaseService.query( | 
					
						
							| 
									
										
										
										
											2024-11-12 14:59:41 +08:00
										 |  |  |                         tenant_id=tenant.tenant_id, id=kb_id): | 
					
						
							| 
									
										
										
										
											2024-11-20 19:45:50 +08:00
										 |  |  |                     tenant_ids.append(tenant.tenant_id) | 
					
						
							| 
									
										
										
										
											2024-09-11 19:49:18 +08:00
										 |  |  |                     break | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 return get_json_result( | 
					
						
							| 
									
										
										
										
											2024-11-05 11:02:31 +08:00
										 |  |  |                     data=False, message='Only owner of knowledgebase authorized for this operation.', | 
					
						
							| 
									
										
										
										
											2024-11-15 17:30:56 +08:00
										 |  |  |                     code=settings.RetCode.OPERATING_ERROR) | 
					
						
							| 
									
										
										
										
											2024-09-11 19:49:18 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-12 14:59:41 +08:00
										 |  |  |         e, kb = KnowledgebaseService.get_by_id(kb_ids[0]) | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  |         if not e: | 
					
						
							| 
									
										
										
										
											2024-11-05 11:02:31 +08:00
										 |  |  |             return get_data_error_result(message="Knowledgebase not found!") | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-18 16:09:22 +08:00
										 |  |  |         embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         rerank_mdl = None | 
					
						
							|  |  |  |         if req.get("rerank_id"): | 
					
						
							| 
									
										
										
										
											2024-09-18 16:09:22 +08:00
										 |  |  |             rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"]) | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         if req.get("keyword", False): | 
					
						
							| 
									
										
										
										
											2024-09-18 16:09:22 +08:00
										 |  |  |             chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT) | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  |             question += keyword_extraction(chat_mdl, question) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-01-09 17:07:21 +08:00
										 |  |  |         labels = label_question(question, [kb]) | 
					
						
							| 
									
										
										
										
											2025-01-22 19:43:14 +08:00
										 |  |  |         ranks = settings.retrievaler.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size, | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  |                                similarity_threshold, vector_similarity_weight, top, | 
					
						
							| 
									
										
										
										
											2025-01-09 17:07:21 +08:00
										 |  |  |                                doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), | 
					
						
							|  |  |  |                                rank_feature=labels | 
					
						
							|  |  |  |                                ) | 
					
						
							| 
									
										
										
										
											2025-01-22 19:43:14 +08:00
										 |  |  |         if use_kg: | 
					
						
							|  |  |  |             ck = settings.kg_retrievaler.retrieval(question, | 
					
						
							|  |  |  |                                                    tenant_ids, | 
					
						
							|  |  |  |                                                    kb_ids, | 
					
						
							|  |  |  |                                                    embd_mdl, | 
					
						
							|  |  |  |                                                    LLMBundle(kb.tenant_id, LLMType.CHAT)) | 
					
						
							|  |  |  |             if ck["content_with_weight"]: | 
					
						
							|  |  |  |                 ranks["chunks"].insert(0, ck) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  |         for c in ranks["chunks"]: | 
					
						
							| 
									
										
										
										
											2024-11-19 14:15:25 +08:00
										 |  |  |             c.pop("vector", None) | 
					
						
							| 
									
										
										
										
											2025-01-09 17:07:21 +08:00
										 |  |  |         ranks["labels"] = labels | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         return get_json_result(data=ranks) | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         if str(e).find("not_found") > 0: | 
					
						
							| 
									
										
										
										
											2024-11-05 11:02:31 +08:00
										 |  |  |             return get_json_result(data=False, message='No chunk found! Check the chunk status please!', | 
					
						
							| 
									
										
										
										
											2024-11-15 17:30:56 +08:00
										 |  |  |                                    code=settings.RetCode.DATA_ERROR) | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  |         return server_error_response(e) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-08 21:23:51 +08:00
										 |  |  | @manager.route('/knowledge_graph', methods=['GET'])  # noqa: F821 | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  | @login_required | 
					
						
							|  |  |  | def knowledge_graph(): | 
					
						
							|  |  |  |     doc_id = request.args["doc_id"] | 
					
						
							| 
									
										
										
										
											2024-11-12 14:59:41 +08:00
										 |  |  |     tenant_id = DocumentService.get_tenant_id(doc_id) | 
					
						
							|  |  |  |     kb_ids = KnowledgebaseService.get_kb_ids(tenant_id) | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  |     req = { | 
					
						
							| 
									
										
										
										
											2024-11-15 17:30:56 +08:00
										 |  |  |         "doc_ids": [doc_id], | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  |         "knowledge_graph_kwd": ["graph", "mind_map"] | 
					
						
							|  |  |  |     } | 
					
						
							| 
									
										
										
										
											2024-11-15 17:30:56 +08:00
										 |  |  |     sres = settings.retrievaler.search(req, search.index_name(tenant_id), kb_ids) | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  |     obj = {"graph": {}, "mind_map": {}} | 
					
						
							|  |  |  |     for id in sres.ids[:2]: | 
					
						
							|  |  |  |         ty = sres.field[id]["knowledge_graph_kwd"] | 
					
						
							|  |  |  |         try: | 
					
						
							| 
									
										
										
										
											2024-10-23 10:12:39 +08:00
										 |  |  |             content_json = json.loads(sres.field[id]["content_with_weight"]) | 
					
						
							| 
									
										
										
										
											2024-11-05 11:02:31 +08:00
										 |  |  |         except Exception: | 
					
						
							| 
									
										
										
										
											2024-10-23 10:12:39 +08:00
										 |  |  |             continue | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if ty == 'mind_map': | 
					
						
							|  |  |  |             node_dict = {} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             def repeat_deal(content_json, node_dict): | 
					
						
							|  |  |  |                 if 'id' in content_json: | 
					
						
							|  |  |  |                     if content_json['id'] in node_dict: | 
					
						
							|  |  |  |                         node_name = content_json['id'] | 
					
						
							|  |  |  |                         content_json['id'] += f"({node_dict[content_json['id']]})" | 
					
						
							|  |  |  |                         node_dict[node_name] += 1 | 
					
						
							|  |  |  |                     else: | 
					
						
							|  |  |  |                         node_dict[content_json['id']] = 1 | 
					
						
							|  |  |  |                 if 'children' in content_json and content_json['children']: | 
					
						
							|  |  |  |                     for item in content_json['children']: | 
					
						
							|  |  |  |                         repeat_deal(item, node_dict) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             repeat_deal(content_json, node_dict) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         obj[ty] = content_json | 
					
						
							| 
									
										
										
										
											2024-08-15 09:17:36 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     return get_json_result(data=obj) |