diff --git a/api/apps/api_app.py b/api/apps/api_app.py index 26df561e9..be712f9b1 100644 --- a/api/apps/api_app.py +++ b/api/apps/api_app.py @@ -18,9 +18,10 @@ import os import re from datetime import datetime, timedelta from flask import request, Response +from api.db.services.llm_service import TenantLLMService from flask_login import login_required, current_user -from api.db import FileType, ParserType, FileSource +from api.db import FileType, LLMType, ParserType, FileSource from api.db.db_models import APIToken, API4Conversation, Task, File from api.db.services import duplicate_name from api.db.services.api_service import APITokenService, API4ConversationService @@ -37,6 +38,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, ge from itsdangerous import URLSafeTimedSerializer from api.utils.file_utils import filename_type, thumbnail +from rag.nlp import keyword_extraction from rag.utils.minio_conn import MINIO from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService @@ -694,7 +696,7 @@ def retrieval(): data=False, retmsg='Token is not valid!"', retcode=RetCode.AUTHENTICATION_ERROR) req = request.json - kb_id = req.get("kb_id") + kb_ids = req.get("kb_id",[]) doc_ids = req.get("doc_ids", []) question = req.get("question") page = int(req.get("page", 1)) @@ -704,32 +706,30 @@ def retrieval(): top = int(req.get("top_k", 1024)) try: - e, kb = KnowledgebaseService.get_by_id(kb_id) - if not e: - return get_data_error_result(retmsg="Knowledgebase not found!") + kbs = KnowledgebaseService.get_by_ids(kb_ids) + embd_nms = list(set([kb.embd_id for kb in kbs])) + if len(embd_nms) != 1: + return get_json_result( + data=False, retmsg='Knowledge bases use different embedding models or does not exist."', retcode=RetCode.AUTHENTICATION_ERROR) embd_mdl = TenantLLMService.model_instance( - kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) - + kbs[0].tenant_id, LLMType.EMBEDDING.value, llm_name=kbs[0].embd_id) rerank_mdl = None if req.get("rerank_id"): rerank_mdl = TenantLLMService.model_instance( - kb.tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"]) - + kbs[0].tenant_id, LLMType.RERANK.value, llm_name=req["rerank_id"]) if req.get("keyword", False): - chat_mdl = TenantLLMService.model_instance(kb.tenant_id, LLMType.CHAT) + chat_mdl = TenantLLMService.model_instance(kbs[0].tenant_id, LLMType.CHAT) question += keyword_extraction(chat_mdl, question) - - ranks = retrievaler.retrieval(question, embd_mdl, kb.tenant_id, [kb_id], page, size, - similarity_threshold, vector_similarity_weight, top, - doc_ids, rerank_mdl=rerank_mdl) + ranks = retrievaler.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size, + similarity_threshold, vector_similarity_weight, top, + doc_ids, rerank_mdl=rerank_mdl) for c in ranks["chunks"]: if "vector" in c: del c["vector"] - return get_json_result(data=ranks) except Exception as e: if str(e).find("not_found") > 0: return get_json_result(data=False, retmsg=f'No chunk found! Check the chunk status please!', retcode=RetCode.DATA_ERROR) - return server_error_response(e) + return server_error_response(e) \ No newline at end of file