| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  |  | # | 
					
						
							| 
									
										
										
										
											2024-01-19 19:51:57 +08:00
										 |  |  |  | #  Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  |  | # | 
					
						
							|  |  |  |  | #  Licensed under the Apache License, Version 2.0 (the "License"); | 
					
						
							|  |  |  |  | #  you may not use this file except in compliance with the License. | 
					
						
							|  |  |  |  | #  You may obtain a copy of the License at | 
					
						
							|  |  |  |  | # | 
					
						
							|  |  |  |  | #      http://www.apache.org/licenses/LICENSE-2.0 | 
					
						
							|  |  |  |  | # | 
					
						
							|  |  |  |  | #  Unless required by applicable law or agreed to in writing, software | 
					
						
							|  |  |  |  | #  distributed under the License is distributed on an "AS IS" BASIS, | 
					
						
							|  |  |  |  | #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
					
						
							|  |  |  |  | #  See the License for the specific language governing permissions and | 
					
						
							|  |  |  |  | #  limitations under the License. | 
					
						
							|  |  |  |  | # | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | from flask import request | 
					
						
							|  |  |  |  | from flask_login import login_required, current_user | 
					
						
							|  |  |  |  | from api.db.services.dialog_service import DialogService | 
					
						
							| 
									
										
										
										
											2024-01-18 19:28:37 +08:00
										 |  |  |  | from api.db import StatusEnum | 
					
						
							| 
									
										
										
										
											2024-02-01 18:53:56 +08:00
										 |  |  |  | from api.db.services.knowledgebase_service import KnowledgebaseService | 
					
						
							| 
									
										
										
										
											2024-01-18 19:28:37 +08:00
										 |  |  |  | from api.db.services.user_service import TenantService | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  |  | from api.utils.api_utils import server_error_response, get_data_error_result, validate_request | 
					
						
							|  |  |  |  | from api.utils import get_uuid | 
					
						
							|  |  |  |  | from api.utils.api_utils import get_json_result | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | @manager.route('/set', methods=['POST']) | 
					
						
							|  |  |  |  | @login_required | 
					
						
							| 
									
										
										
										
											2024-02-19 19:22:17 +08:00
										 |  |  |  | def set_dialog(): | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  |  |     req = request.json | 
					
						
							|  |  |  |  |     dialog_id = req.get("dialog_id") | 
					
						
							|  |  |  |  |     name = req.get("name", "New Dialog") | 
					
						
							|  |  |  |  |     description = req.get("description", "A helpful Dialog") | 
					
						
							| 
									
										
										
										
											2024-06-06 11:13:39 +08:00
										 |  |  |  |     icon = req.get("icon", "") | 
					
						
							| 
									
										
										
										
											2024-02-21 16:32:38 +08:00
										 |  |  |  |     top_n = req.get("top_n", 6) | 
					
						
							| 
									
										
										
										
											2024-05-29 16:50:02 +08:00
										 |  |  |  |     top_k = req.get("top_k", 1024) | 
					
						
							|  |  |  |  |     rerank_id = req.get("rerank_id", "") | 
					
						
							|  |  |  |  |     if not rerank_id: req["rerank_id"] = "" | 
					
						
							| 
									
										
										
										
											2024-02-21 16:32:38 +08:00
										 |  |  |  |     similarity_threshold = req.get("similarity_threshold", 0.1) | 
					
						
							|  |  |  |  |     vector_similarity_weight = req.get("vector_similarity_weight", 0.3) | 
					
						
							| 
									
										
										
										
											2024-05-30 09:25:05 +08:00
										 |  |  |  |     if vector_similarity_weight is None: vector_similarity_weight = 0.3 | 
					
						
							| 
									
										
										
										
											2024-04-30 11:04:14 +08:00
										 |  |  |  |     llm_setting = req.get("llm_setting", {}) | 
					
						
							| 
									
										
										
										
											2024-02-21 16:32:38 +08:00
										 |  |  |  |     default_prompt = { | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  |  |         "system": """你是一个智能助手,请总结知识库的内容来回答问题,请列举知识库中的数据详细回答。当所有知识库内容都与问题无关时,你的回答必须包括“知识库中未找到您要的答案!”这句话。回答需要考虑聊天历史。
 | 
					
						
							|  |  |  |  | 以下是知识库: | 
					
						
							|  |  |  |  | {knowledge} | 
					
						
							|  |  |  |  | 以上是知识库。""",
 | 
					
						
							|  |  |  |  |         "prologue": "您好,我是您的助手小樱,长得可爱又善良,can I help you?", | 
					
						
							|  |  |  |  |         "parameters": [ | 
					
						
							|  |  |  |  |             {"key": "knowledge", "optional": False} | 
					
						
							|  |  |  |  |         ], | 
					
						
							|  |  |  |  |         "empty_response": "Sorry! 知识库中未找到相关内容!" | 
					
						
							| 
									
										
										
										
											2024-02-21 16:32:38 +08:00
										 |  |  |  |     } | 
					
						
							|  |  |  |  |     prompt_config = req.get("prompt_config", default_prompt) | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-03-27 11:33:46 +08:00
										 |  |  |  |     if not prompt_config["system"]: | 
					
						
							|  |  |  |  |         prompt_config["system"] = default_prompt["system"] | 
					
						
							| 
									
										
										
										
											2024-02-21 16:32:38 +08:00
										 |  |  |  |     # if len(prompt_config["parameters"]) < 1: | 
					
						
							|  |  |  |  |     #     prompt_config["parameters"] = default_prompt["parameters"] | 
					
						
							|  |  |  |  |     # for p in prompt_config["parameters"]: | 
					
						
							|  |  |  |  |     #     if p["key"] == "knowledge":break | 
					
						
							|  |  |  |  |     # else: prompt_config["parameters"].append(default_prompt["parameters"][0]) | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  |     for p in prompt_config["parameters"]: | 
					
						
							| 
									
										
										
										
											2024-03-27 11:33:46 +08:00
										 |  |  |  |         if p["optional"]: | 
					
						
							|  |  |  |  |             continue | 
					
						
							| 
									
										
										
										
											2024-02-21 16:32:38 +08:00
										 |  |  |  |         if prompt_config["system"].find("{%s}" % p["key"]) < 0: | 
					
						
							| 
									
										
										
										
											2024-03-27 11:33:46 +08:00
										 |  |  |  |             return get_data_error_result( | 
					
						
							|  |  |  |  |                 retmsg="Parameter '{}' is not used".format(p["key"])) | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  |     try: | 
					
						
							|  |  |  |  |         e, tenant = TenantService.get_by_id(current_user.id) | 
					
						
							| 
									
										
										
										
											2024-03-27 11:33:46 +08:00
										 |  |  |  |         if not e: | 
					
						
							|  |  |  |  |             return get_data_error_result(retmsg="Tenant not found!") | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  |  |         llm_id = req.get("llm_id", tenant.llm_id) | 
					
						
							|  |  |  |  |         if not dialog_id: | 
					
						
							| 
									
										
										
										
											2024-03-27 11:33:46 +08:00
										 |  |  |  |             if not req.get("kb_ids"): | 
					
						
							|  |  |  |  |                 return get_data_error_result( | 
					
						
							|  |  |  |  |                     retmsg="Fail! Please select knowledgebase!") | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  |  |             dia = { | 
					
						
							|  |  |  |  |                 "id": get_uuid(), | 
					
						
							|  |  |  |  |                 "tenant_id": current_user.id, | 
					
						
							|  |  |  |  |                 "name": name, | 
					
						
							| 
									
										
										
										
											2024-02-21 16:32:38 +08:00
										 |  |  |  |                 "kb_ids": req["kb_ids"], | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  |  |                 "description": description, | 
					
						
							|  |  |  |  |                 "llm_id": llm_id, | 
					
						
							|  |  |  |  |                 "llm_setting": llm_setting, | 
					
						
							| 
									
										
										
										
											2024-02-21 16:32:38 +08:00
										 |  |  |  |                 "prompt_config": prompt_config, | 
					
						
							|  |  |  |  |                 "top_n": top_n, | 
					
						
							| 
									
										
										
										
											2024-05-29 16:50:02 +08:00
										 |  |  |  |                 "top_k": top_k, | 
					
						
							|  |  |  |  |                 "rerank_id": rerank_id, | 
					
						
							| 
									
										
										
										
											2024-02-21 16:32:38 +08:00
										 |  |  |  |                 "similarity_threshold": similarity_threshold, | 
					
						
							| 
									
										
										
										
											2024-06-03 13:42:56 +08:00
										 |  |  |  |                 "vector_similarity_weight": vector_similarity_weight, | 
					
						
							| 
									
										
										
										
											2024-06-06 11:13:39 +08:00
										 |  |  |  |                 "icon": icon | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  |  |             } | 
					
						
							| 
									
										
										
										
											2024-03-27 11:33:46 +08:00
										 |  |  |  |             if not DialogService.save(**dia): | 
					
						
							|  |  |  |  |                 return get_data_error_result(retmsg="Fail to new a dialog!") | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  |  |             e, dia = DialogService.get_by_id(dia["id"]) | 
					
						
							| 
									
										
										
										
											2024-03-27 11:33:46 +08:00
										 |  |  |  |             if not e: | 
					
						
							|  |  |  |  |                 return get_data_error_result(retmsg="Fail to new a dialog!") | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  |  |             return get_json_result(data=dia.to_json()) | 
					
						
							|  |  |  |  |         else: | 
					
						
							|  |  |  |  |             del req["dialog_id"] | 
					
						
							| 
									
										
										
										
											2024-03-27 11:33:46 +08:00
										 |  |  |  |             if "kb_names" in req: | 
					
						
							|  |  |  |  |                 del req["kb_names"] | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  |  |             if not DialogService.update_by_id(dialog_id, req): | 
					
						
							|  |  |  |  |                 return get_data_error_result(retmsg="Dialog not found!") | 
					
						
							|  |  |  |  |             e, dia = DialogService.get_by_id(dialog_id) | 
					
						
							| 
									
										
										
										
											2024-03-27 11:33:46 +08:00
										 |  |  |  |             if not e: | 
					
						
							|  |  |  |  |                 return get_data_error_result(retmsg="Fail to update a dialog!") | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  |  |             dia = dia.to_dict() | 
					
						
							|  |  |  |  |             dia["kb_ids"], dia["kb_names"] = get_kb_names(dia["kb_ids"]) | 
					
						
							|  |  |  |  |             return get_json_result(data=dia) | 
					
						
							|  |  |  |  |     except Exception as e: | 
					
						
							|  |  |  |  |         return server_error_response(e) | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-18 19:28:37 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  |  | @manager.route('/get', methods=['GET']) | 
					
						
							|  |  |  |  | @login_required | 
					
						
							|  |  |  |  | def get(): | 
					
						
							|  |  |  |  |     dialog_id = request.args["dialog_id"] | 
					
						
							|  |  |  |  |     try: | 
					
						
							| 
									
										
										
										
											2024-02-21 16:32:38 +08:00
										 |  |  |  |         e, dia = DialogService.get_by_id(dialog_id) | 
					
						
							| 
									
										
										
										
											2024-03-27 11:33:46 +08:00
										 |  |  |  |         if not e: | 
					
						
							|  |  |  |  |             return get_data_error_result(retmsg="Dialog not found!") | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  |  |         dia = dia.to_dict() | 
					
						
							|  |  |  |  |         dia["kb_ids"], dia["kb_names"] = get_kb_names(dia["kb_ids"]) | 
					
						
							|  |  |  |  |         return get_json_result(data=dia) | 
					
						
							|  |  |  |  |     except Exception as e: | 
					
						
							|  |  |  |  |         return server_error_response(e) | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-21 16:32:38 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  |  | def get_kb_names(kb_ids): | 
					
						
							|  |  |  |  |     ids, nms = [], [] | 
					
						
							|  |  |  |  |     for kid in kb_ids: | 
					
						
							|  |  |  |  |         e, kb = KnowledgebaseService.get_by_id(kid) | 
					
						
							| 
									
										
										
										
											2024-03-27 11:33:46 +08:00
										 |  |  |  |         if not e or kb.status != StatusEnum.VALID.value: | 
					
						
							|  |  |  |  |             continue | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  |  |         ids.append(kid) | 
					
						
							|  |  |  |  |         nms.append(kb.name) | 
					
						
							|  |  |  |  |     return ids, nms | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-21 16:32:38 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  |  | @manager.route('/list', methods=['GET']) | 
					
						
							|  |  |  |  | @login_required | 
					
						
							| 
									
										
										
										
											2024-05-14 14:48:15 +08:00
										 |  |  |  | def list_dialogs(): | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  |  |     try: | 
					
						
							| 
									
										
										
										
											2024-03-27 11:33:46 +08:00
										 |  |  |  |         diags = DialogService.query( | 
					
						
							|  |  |  |  |             tenant_id=current_user.id, | 
					
						
							|  |  |  |  |             status=StatusEnum.VALID.value, | 
					
						
							|  |  |  |  |             reverse=True, | 
					
						
							|  |  |  |  |             order_by=DialogService.model.create_time) | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  |  |         diags = [d.to_dict() for d in diags] | 
					
						
							|  |  |  |  |         for d in diags: | 
					
						
							|  |  |  |  |             d["kb_ids"], d["kb_names"] = get_kb_names(d["kb_ids"]) | 
					
						
							|  |  |  |  |         return get_json_result(data=diags) | 
					
						
							| 
									
										
										
										
											2024-01-18 19:28:37 +08:00
										 |  |  |  |     except Exception as e: | 
					
						
							|  |  |  |  |         return server_error_response(e) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | @manager.route('/rm', methods=['POST']) | 
					
						
							|  |  |  |  | @login_required | 
					
						
							| 
									
										
										
										
											2024-02-21 16:32:38 +08:00
										 |  |  |  | @validate_request("dialog_ids") | 
					
						
							| 
									
										
										
										
											2024-01-18 19:28:37 +08:00
										 |  |  |  | def rm(): | 
					
						
							|  |  |  |  |     req = request.json | 
					
						
							|  |  |  |  |     try: | 
					
						
							| 
									
										
										
										
											2024-03-27 11:33:46 +08:00
										 |  |  |  |         DialogService.update_many_by_id( | 
					
						
							|  |  |  |  |             [{"id": id, "status": StatusEnum.INVALID.value} for id in req["dialog_ids"]]) | 
					
						
							| 
									
										
										
										
											2024-01-18 19:28:37 +08:00
										 |  |  |  |         return get_json_result(data=True) | 
					
						
							| 
									
										
										
										
											2024-01-17 20:20:42 +08:00
										 |  |  |  |     except Exception as e: | 
					
						
							| 
									
										
										
										
											2024-02-21 16:32:38 +08:00
										 |  |  |  |         return server_error_response(e) |