| 
									
										
										
										
											2024-01-17 09:43:27 +08:00
										 |  |  | # | 
					
						
							| 
									
										
										
										
											2024-01-19 19:51:57 +08:00
										 |  |  | #  Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | 
					
						
							| 
									
										
										
										
											2024-01-17 09:43:27 +08:00
										 |  |  | # | 
					
						
							|  |  |  | #  Licensed under the Apache License, Version 2.0 (the "License"); | 
					
						
							|  |  |  | #  you may not use this file except in compliance with the License. | 
					
						
							|  |  |  | #  You may obtain a copy of the License at | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | #      http://www.apache.org/licenses/LICENSE-2.0 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | #  Unless required by applicable law or agreed to in writing, software | 
					
						
							|  |  |  | #  distributed under the License is distributed on an "AS IS" BASIS, | 
					
						
							|  |  |  | #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
					
						
							|  |  |  | #  See the License for the specific language governing permissions and | 
					
						
							|  |  |  | #  limitations under the License. | 
					
						
							|  |  |  | # | 
					
						
							| 
									
										
										
										
											2024-01-19 19:51:57 +08:00
										 |  |  | import datetime | 
					
						
							| 
									
										
										
										
											2024-01-17 09:43:27 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | from flask import request | 
					
						
							|  |  |  | from flask_login import login_required, current_user | 
					
						
							| 
									
										
										
										
											2024-01-18 19:28:37 +08:00
										 |  |  | from elasticsearch_dsl import Q | 
					
						
							| 
									
										
										
										
											2024-02-01 18:53:56 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | from rag.app.qa import rmPrefix, beAdoc | 
					
						
							| 
									
										
										
										
											2024-04-28 19:13:33 +08:00
										 |  |  | from rag.nlp import search, rag_tokenizer | 
					
						
							| 
									
										
										
										
											2024-04-28 13:19:54 +08:00
										 |  |  | from rag.utils.es_conn import ELASTICSEARCH | 
					
						
							|  |  |  | from rag.utils import rmSpace | 
					
						
							| 
									
										
										
										
											2024-02-01 18:53:56 +08:00
										 |  |  | from api.db import LLMType, ParserType | 
					
						
							|  |  |  | from api.db.services.knowledgebase_service import KnowledgebaseService | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  | from api.db.services.llm_service import TenantLLMService | 
					
						
							|  |  |  | from api.db.services.user_service import UserTenantService | 
					
						
							|  |  |  | from api.utils.api_utils import server_error_response, get_data_error_result, validate_request | 
					
						
							|  |  |  | from api.db.services.document_service import DocumentService | 
					
						
							| 
									
										
										
										
											2024-02-27 14:57:34 +08:00
										 |  |  | from api.settings import RetCode, retrievaler | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  | from api.utils.api_utils import get_json_result | 
					
						
							| 
									
										
										
										
											2024-01-18 19:28:37 +08:00
										 |  |  | import hashlib | 
					
						
							|  |  |  | import re | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-17 09:43:27 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | @manager.route('/list', methods=['POST']) | 
					
						
							|  |  |  | @login_required | 
					
						
							|  |  |  | @validate_request("doc_id") | 
					
						
							| 
									
										
										
										
											2024-05-14 14:48:15 +08:00
										 |  |  | def list_chunk(): | 
					
						
							| 
									
										
										
										
											2024-01-17 09:43:27 +08:00
										 |  |  |     req = request.json | 
					
						
							|  |  |  |     doc_id = req["doc_id"] | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  |     page = int(req.get("page", 1)) | 
					
						
							|  |  |  |     size = int(req.get("size", 30)) | 
					
						
							| 
									
										
										
										
											2024-01-17 09:43:27 +08:00
										 |  |  |     question = req.get("keywords", "") | 
					
						
							|  |  |  |     try: | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  |         tenant_id = DocumentService.get_tenant_id(req["doc_id"]) | 
					
						
							| 
									
										
										
										
											2024-01-18 19:28:37 +08:00
										 |  |  |         if not tenant_id: | 
					
						
							|  |  |  |             return get_data_error_result(retmsg="Tenant not found!") | 
					
						
							| 
									
										
										
										
											2024-02-07 19:27:23 +08:00
										 |  |  |         e, doc = DocumentService.get_by_id(doc_id) | 
					
						
							|  |  |  |         if not e: | 
					
						
							|  |  |  |             return get_data_error_result(retmsg="Document not found!") | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  |         query = { | 
					
						
							| 
									
										
										
										
											2024-03-04 14:42:26 +08:00
										 |  |  |             "doc_ids": [doc_id], "page": page, "size": size, "question": question, "sort": True | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  |         } | 
					
						
							| 
									
										
										
										
											2024-01-18 19:28:37 +08:00
										 |  |  |         if "available_int" in req: | 
					
						
							|  |  |  |             query["available_int"] = int(req["available_int"]) | 
					
						
							|  |  |  |         sres = retrievaler.search(query, search.index_name(tenant_id)) | 
					
						
							| 
									
										
										
										
											2024-02-07 19:27:23 +08:00
										 |  |  |         res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()} | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  |         for id in sres.ids: | 
					
						
							|  |  |  |             d = { | 
					
						
							|  |  |  |                 "chunk_id": id, | 
					
						
							| 
									
										
										
										
											2024-04-15 14:43:44 +08:00
										 |  |  |                 "content_with_weight": rmSpace(sres.highlight[id]) if question and id in  sres.highlight else sres.field[id].get( | 
					
						
							| 
									
										
										
										
											2024-03-04 17:08:35 +08:00
										 |  |  |                     "content_with_weight", ""), | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  |                 "doc_id": sres.field[id]["doc_id"], | 
					
						
							|  |  |  |                 "docnm_kwd": sres.field[id]["docnm_kwd"], | 
					
						
							|  |  |  |                 "important_kwd": sres.field[id].get("important_kwd", []), | 
					
						
							|  |  |  |                 "img_id": sres.field[id].get("img_id", ""), | 
					
						
							|  |  |  |                 "available_int": sres.field[id].get("available_int", 1), | 
					
						
							| 
									
										
										
										
											2024-03-04 14:42:26 +08:00
										 |  |  |                 "positions": sres.field[id].get("position_int", "").split("\t") | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  |             } | 
					
						
							| 
									
										
										
										
											2024-03-04 17:08:35 +08:00
										 |  |  |             if len(d["positions"]) % 5 == 0: | 
					
						
							|  |  |  |                 poss = [] | 
					
						
							|  |  |  |                 for i in range(0, len(d["positions"]), 5): | 
					
						
							|  |  |  |                     poss.append([float(d["positions"][i]), float(d["positions"][i + 1]), float(d["positions"][i + 2]), | 
					
						
							|  |  |  |                                  float(d["positions"][i + 3]), float(d["positions"][i + 4])]) | 
					
						
							|  |  |  |                 d["positions"] = poss | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  |             res["chunks"].append(d) | 
					
						
							| 
									
										
										
										
											2024-01-17 09:43:27 +08:00
										 |  |  |         return get_json_result(data=res) | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         if str(e).find("not_found") > 0: | 
					
						
							| 
									
										
										
										
											2024-04-03 11:00:50 +08:00
										 |  |  |             return get_json_result(data=False, retmsg=f'No chunk found!', | 
					
						
							| 
									
										
										
										
											2024-01-18 19:28:37 +08:00
										 |  |  |                                    retcode=RetCode.DATA_ERROR) | 
					
						
							| 
									
										
										
										
											2024-01-17 09:43:27 +08:00
										 |  |  |         return server_error_response(e) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @manager.route('/get', methods=['GET']) | 
					
						
							|  |  |  | @login_required | 
					
						
							|  |  |  | def get(): | 
					
						
							|  |  |  |     chunk_id = request.args["chunk_id"] | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         tenants = UserTenantService.query(user_id=current_user.id) | 
					
						
							|  |  |  |         if not tenants: | 
					
						
							|  |  |  |             return get_data_error_result(retmsg="Tenant not found!") | 
					
						
							| 
									
										
										
										
											2024-01-18 19:28:37 +08:00
										 |  |  |         res = ELASTICSEARCH.get( | 
					
						
							|  |  |  |             chunk_id, search.index_name( | 
					
						
							|  |  |  |                 tenants[0].tenant_id)) | 
					
						
							|  |  |  |         if not res.get("found"): | 
					
						
							|  |  |  |             return server_error_response("Chunk not found") | 
					
						
							| 
									
										
										
										
											2024-01-17 09:43:27 +08:00
										 |  |  |         id = res["_id"] | 
					
						
							|  |  |  |         res = res["_source"] | 
					
						
							|  |  |  |         res["chunk_id"] = id | 
					
						
							|  |  |  |         k = [] | 
					
						
							|  |  |  |         for n in res.keys(): | 
					
						
							| 
									
										
										
										
											2024-02-01 18:53:56 +08:00
										 |  |  |             if re.search(r"(_vec$|_sm_|_tks|_ltks)", n): | 
					
						
							| 
									
										
										
										
											2024-01-17 09:43:27 +08:00
										 |  |  |                 k.append(n) | 
					
						
							| 
									
										
										
										
											2024-01-18 19:28:37 +08:00
										 |  |  |         for n in k: | 
					
						
							|  |  |  |             del res[n] | 
					
						
							| 
									
										
										
										
											2024-01-17 09:43:27 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         return get_json_result(data=res) | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         if str(e).find("NotFoundError") >= 0: | 
					
						
							|  |  |  |             return get_json_result(data=False, retmsg=f'Chunk not found!', | 
					
						
							|  |  |  |                                    retcode=RetCode.DATA_ERROR) | 
					
						
							|  |  |  |         return server_error_response(e) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @manager.route('/set', methods=['POST']) | 
					
						
							|  |  |  | @login_required | 
					
						
							| 
									
										
										
										
											2024-02-01 18:53:56 +08:00
										 |  |  | @validate_request("doc_id", "chunk_id", "content_with_weight", | 
					
						
							| 
									
										
										
										
											2024-01-18 19:28:37 +08:00
										 |  |  |                   "important_kwd") | 
					
						
							| 
									
										
										
										
											2024-01-17 09:43:27 +08:00
										 |  |  | def set(): | 
					
						
							|  |  |  |     req = request.json | 
					
						
							| 
									
										
										
										
											2024-03-27 11:33:46 +08:00
										 |  |  |     d = { | 
					
						
							|  |  |  |         "id": req["chunk_id"], | 
					
						
							|  |  |  |         "content_with_weight": req["content_with_weight"]} | 
					
						
							| 
									
										
										
										
											2024-04-28 19:13:33 +08:00
										 |  |  |     d["content_ltks"] = rag_tokenizer.tokenize(req["content_with_weight"]) | 
					
						
							|  |  |  |     d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) | 
					
						
							| 
									
										
										
										
											2024-01-17 09:43:27 +08:00
										 |  |  |     d["important_kwd"] = req["important_kwd"] | 
					
						
							| 
									
										
										
										
											2024-04-28 19:13:33 +08:00
										 |  |  |     d["important_tks"] = rag_tokenizer.tokenize(" ".join(req["important_kwd"])) | 
					
						
							| 
									
										
										
										
											2024-01-18 19:28:37 +08:00
										 |  |  |     if "available_int" in req: | 
					
						
							|  |  |  |         d["available_int"] = req["available_int"] | 
					
						
							| 
									
										
										
										
											2024-01-17 09:43:27 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         tenant_id = DocumentService.get_tenant_id(req["doc_id"]) | 
					
						
							| 
									
										
										
										
											2024-01-18 19:28:37 +08:00
										 |  |  |         if not tenant_id: | 
					
						
							|  |  |  |             return get_data_error_result(retmsg="Tenant not found!") | 
					
						
							| 
									
										
										
										
											2024-06-19 14:44:25 +08:00
										 |  |  |          | 
					
						
							|  |  |  |         embd_id = DocumentService.get_embd_id(req["doc_id"]) | 
					
						
							| 
									
										
										
										
											2024-01-18 19:28:37 +08:00
										 |  |  |         embd_mdl = TenantLLMService.model_instance( | 
					
						
							| 
									
										
										
										
											2024-06-19 14:44:25 +08:00
										 |  |  |             tenant_id, LLMType.EMBEDDING.value, embd_id) | 
					
						
							|  |  |  |          | 
					
						
							| 
									
										
										
										
											2024-01-18 19:28:37 +08:00
										 |  |  |         e, doc = DocumentService.get_by_id(req["doc_id"]) | 
					
						
							|  |  |  |         if not e: | 
					
						
							|  |  |  |             return get_data_error_result(retmsg="Document not found!") | 
					
						
							| 
									
										
										
										
											2024-02-01 18:53:56 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         if doc.parser_id == ParserType.QA: | 
					
						
							| 
									
										
										
										
											2024-03-27 11:33:46 +08:00
										 |  |  |             arr = [ | 
					
						
							|  |  |  |                 t for t in re.split( | 
					
						
							|  |  |  |                     r"[\n\t]", | 
					
						
							|  |  |  |                     req["content_with_weight"]) if len(t) > 1] | 
					
						
							|  |  |  |             if len(arr) != 2: | 
					
						
							|  |  |  |                 return get_data_error_result( | 
					
						
							|  |  |  |                     retmsg="Q&A must be separated by TAB/ENTER key.") | 
					
						
							| 
									
										
										
										
											2024-05-27 08:20:32 +08:00
										 |  |  |             q, a = rmPrefix(arr[0]), rmPrefix(arr[1]) | 
					
						
							| 
									
										
										
										
											2024-03-27 11:33:46 +08:00
										 |  |  |             d = beAdoc(d, arr[0], arr[1], not any( | 
					
						
							| 
									
										
										
										
											2024-04-28 19:13:33 +08:00
										 |  |  |                 [rag_tokenizer.is_chinese(t) for t in q + a])) | 
					
						
							| 
									
										
										
										
											2024-02-01 18:53:56 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-05 18:08:17 +08:00
										 |  |  |         v, c = embd_mdl.encode([doc.name, req["content_with_weight"]]) | 
					
						
							| 
									
										
										
										
											2024-02-01 18:53:56 +08:00
										 |  |  |         v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1] | 
					
						
							| 
									
										
										
										
											2024-01-18 19:28:37 +08:00
										 |  |  |         d["q_%d_vec" % len(v)] = v.tolist() | 
					
						
							| 
									
										
										
										
											2024-01-17 09:43:27 +08:00
										 |  |  |         ELASTICSEARCH.upsert([d], search.index_name(tenant_id)) | 
					
						
							|  |  |  |         return get_json_result(data=True) | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         return server_error_response(e) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  | @manager.route('/switch', methods=['POST']) | 
					
						
							|  |  |  | @login_required | 
					
						
							|  |  |  | @validate_request("chunk_ids", "available_int", "doc_id") | 
					
						
							|  |  |  | def switch(): | 
					
						
							|  |  |  |     req = request.json | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         tenant_id = DocumentService.get_tenant_id(req["doc_id"]) | 
					
						
							| 
									
										
										
										
											2024-01-18 19:28:37 +08:00
										 |  |  |         if not tenant_id: | 
					
						
							|  |  |  |             return get_data_error_result(retmsg="Tenant not found!") | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  |         if not ELASTICSEARCH.upsert([{"id": i, "available_int": int(req["available_int"])} for i in req["chunk_ids"]], | 
					
						
							| 
									
										
										
										
											2024-01-18 19:28:37 +08:00
										 |  |  |                                     search.index_name(tenant_id)): | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  |             return get_data_error_result(retmsg="Index updating failure") | 
					
						
							|  |  |  |         return get_json_result(data=True) | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         return server_error_response(e) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-18 19:28:37 +08:00
										 |  |  | @manager.route('/rm', methods=['POST']) | 
					
						
							|  |  |  | @login_required | 
					
						
							|  |  |  | @validate_request("chunk_ids") | 
					
						
							|  |  |  | def rm(): | 
					
						
							|  |  |  |     req = request.json | 
					
						
							|  |  |  |     try: | 
					
						
							| 
									
										
										
										
											2024-03-27 11:33:46 +08:00
										 |  |  |         if not ELASTICSEARCH.deleteByQuery( | 
					
						
							|  |  |  |                 Q("ids", values=req["chunk_ids"]), search.index_name(current_user.id)): | 
					
						
							| 
									
										
										
										
											2024-01-18 19:28:37 +08:00
										 |  |  |             return get_data_error_result(retmsg="Index updating failure") | 
					
						
							|  |  |  |         return get_json_result(data=True) | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         return server_error_response(e) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-17 09:43:27 +08:00
										 |  |  | @manager.route('/create', methods=['POST']) | 
					
						
							|  |  |  | @login_required | 
					
						
							| 
									
										
										
										
											2024-02-05 18:08:17 +08:00
										 |  |  | @validate_request("doc_id", "content_with_weight") | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  | def create(): | 
					
						
							| 
									
										
										
										
											2024-01-17 09:43:27 +08:00
										 |  |  |     req = request.json | 
					
						
							|  |  |  |     md5 = hashlib.md5() | 
					
						
							| 
									
										
										
										
											2024-02-05 18:08:17 +08:00
										 |  |  |     md5.update((req["content_with_weight"] + req["doc_id"]).encode("utf-8")) | 
					
						
							| 
									
										
										
										
											2024-01-17 09:43:27 +08:00
										 |  |  |     chunck_id = md5.hexdigest() | 
					
						
							| 
									
										
										
										
											2024-04-28 19:13:33 +08:00
										 |  |  |     d = {"id": chunck_id, "content_ltks": rag_tokenizer.tokenize(req["content_with_weight"]), | 
					
						
							| 
									
										
										
										
											2024-03-04 17:08:35 +08:00
										 |  |  |          "content_with_weight": req["content_with_weight"]} | 
					
						
							| 
									
										
										
										
											2024-04-28 19:13:33 +08:00
										 |  |  |     d["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(d["content_ltks"]) | 
					
						
							| 
									
										
										
										
											2024-01-18 19:28:37 +08:00
										 |  |  |     d["important_kwd"] = req.get("important_kwd", []) | 
					
						
							| 
									
										
										
										
											2024-04-28 19:13:33 +08:00
										 |  |  |     d["important_tks"] = rag_tokenizer.tokenize(" ".join(req.get("important_kwd", []))) | 
					
						
							| 
									
										
										
										
											2024-01-19 19:51:57 +08:00
										 |  |  |     d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19] | 
					
						
							| 
									
										
										
										
											2024-02-07 19:27:23 +08:00
										 |  |  |     d["create_timestamp_flt"] = datetime.datetime.now().timestamp() | 
					
						
							| 
									
										
										
										
											2024-01-17 09:43:27 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         e, doc = DocumentService.get_by_id(req["doc_id"]) | 
					
						
							| 
									
										
										
										
											2024-01-18 19:28:37 +08:00
										 |  |  |         if not e: | 
					
						
							|  |  |  |             return get_data_error_result(retmsg="Document not found!") | 
					
						
							| 
									
										
										
										
											2024-01-17 09:43:27 +08:00
										 |  |  |         d["kb_id"] = [doc.kb_id] | 
					
						
							|  |  |  |         d["docnm_kwd"] = doc.name | 
					
						
							|  |  |  |         d["doc_id"] = doc.id | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         tenant_id = DocumentService.get_tenant_id(req["doc_id"]) | 
					
						
							| 
									
										
										
										
											2024-01-18 19:28:37 +08:00
										 |  |  |         if not tenant_id: | 
					
						
							|  |  |  |             return get_data_error_result(retmsg="Tenant not found!") | 
					
						
							| 
									
										
										
										
											2024-06-20 09:43:39 +08:00
										 |  |  |          | 
					
						
							|  |  |  |         embd_id = DocumentService.get_embd_id(req["doc_id"]) | 
					
						
							| 
									
										
										
										
											2024-01-18 19:28:37 +08:00
										 |  |  |         embd_mdl = TenantLLMService.model_instance( | 
					
						
							| 
									
										
										
										
											2024-06-20 09:43:39 +08:00
										 |  |  |             tenant_id, LLMType.EMBEDDING.value, embd_id) | 
					
						
							|  |  |  |          | 
					
						
							| 
									
										
										
										
											2024-02-05 18:08:17 +08:00
										 |  |  |         v, c = embd_mdl.encode([doc.name, req["content_with_weight"]]) | 
					
						
							| 
									
										
										
										
											2024-01-17 09:43:27 +08:00
										 |  |  |         DocumentService.increment_chunk_num(req["doc_id"], doc.kb_id, c, 1, 0) | 
					
						
							|  |  |  |         v = 0.1 * v[0] + 0.9 * v[1] | 
					
						
							| 
									
										
										
										
											2024-01-18 19:28:37 +08:00
										 |  |  |         d["q_%d_vec" % len(v)] = v.tolist() | 
					
						
							| 
									
										
										
										
											2024-01-17 09:43:27 +08:00
										 |  |  |         ELASTICSEARCH.upsert([d], search.index_name(tenant_id)) | 
					
						
							| 
									
										
										
										
											2024-05-27 11:01:20 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         DocumentService.increment_chunk_num( | 
					
						
							|  |  |  |             doc.id, doc.kb_id, c, 1, 0) | 
					
						
							| 
									
										
										
										
											2024-01-17 09:43:27 +08:00
										 |  |  |         return get_json_result(data={"chunk_id": chunck_id}) | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         return server_error_response(e) | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @manager.route('/retrieval_test', methods=['POST']) | 
					
						
							|  |  |  | @login_required | 
					
						
							|  |  |  | @validate_request("kb_id", "question") | 
					
						
							|  |  |  | def retrieval_test(): | 
					
						
							|  |  |  |     req = request.json | 
					
						
							|  |  |  |     page = int(req.get("page", 1)) | 
					
						
							|  |  |  |     size = int(req.get("size", 30)) | 
					
						
							|  |  |  |     question = req["question"] | 
					
						
							|  |  |  |     kb_id = req["kb_id"] | 
					
						
							|  |  |  |     doc_ids = req.get("doc_ids", []) | 
					
						
							| 
									
										
										
										
											2024-01-22 19:51:38 +08:00
										 |  |  |     similarity_threshold = float(req.get("similarity_threshold", 0.2)) | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  |     vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3)) | 
					
						
							| 
									
										
										
										
											2024-02-08 17:01:01 +08:00
										 |  |  |     top = int(req.get("top_k", 1024)) | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  |     try: | 
					
						
							|  |  |  |         e, kb = KnowledgebaseService.get_by_id(kb_id) | 
					
						
							|  |  |  |         if not e: | 
					
						
							|  |  |  |             return get_data_error_result(retmsg="Knowledgebase not found!") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-18 19:28:37 +08:00
										 |  |  |         embd_mdl = TenantLLMService.model_instance( | 
					
						
							| 
									
										
										
										
											2024-04-16 16:42:19 +08:00
										 |  |  |             kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) | 
					
						
							| 
									
										
										
										
											2024-05-29 16:50:02 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         rerank_mdl = None | 
					
						
							|  |  |  |         if req.get("rerank_id"): | 
					
						
							|  |  |  |             rerank_mdl = TenantLLMService.model_instance( | 
					
						
							|  |  |  |                 kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"]) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size, | 
					
						
							|  |  |  |                                       similarity_threshold, vector_similarity_weight, top, | 
					
						
							|  |  |  |                                       doc_ids, rerank_mdl=rerank_mdl) | 
					
						
							| 
									
										
										
										
											2024-02-08 17:01:01 +08:00
										 |  |  |         for c in ranks["chunks"]: | 
					
						
							|  |  |  |             if "vector" in c: | 
					
						
							|  |  |  |                 del c["vector"] | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         return get_json_result(data=ranks) | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         if str(e).find("not_found") > 0: | 
					
						
							| 
									
										
										
										
											2024-04-03 11:00:50 +08:00
										 |  |  |             return get_json_result(data=False, retmsg=f'No chunk found! Check the chunk status please!', | 
					
						
							| 
									
										
										
										
											2024-01-18 19:28:37 +08:00
										 |  |  |                                    retcode=RetCode.DATA_ERROR) | 
					
						
							|  |  |  |         return server_error_response(e) |