diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 839afdb9fd..a499719fc3 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -133,6 +133,22 @@ class DatasetListApi(DatasetApiResource): parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") args = parser.parse_args() + + if args.get("embedding_model_provider"): + DatasetService.check_embedding_model_setting( + tenant_id, args.get("embedding_model_provider"), args.get("embedding_model") + ) + if ( + args.get("retrieval_model") + and args.get("retrieval_model").get("reranking_model") + and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name") + ): + DatasetService.check_reranking_model_setting( + tenant_id, + args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"), + args.get("retrieval_model").get("reranking_model").get("reranking_model_name"), + ) + try: dataset = DatasetService.create_empty_dataset( tenant_id=tenant_id, @@ -265,10 +281,20 @@ class DatasetApi(DatasetApiResource): data = request.get_json() # check embedding model setting - if data.get("indexing_technique") == "high_quality": + if data.get("indexing_technique") == "high_quality" or data.get("embedding_model_provider"): DatasetService.check_embedding_model_setting( dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model") ) + if ( + data.get("retrieval_model") + and data.get("retrieval_model").get("reranking_model") + and data.get("retrieval_model").get("reranking_model").get("reranking_provider_name") + ): + DatasetService.check_reranking_model_setting( + dataset.tenant_id, + data.get("retrieval_model").get("reranking_model").get("reranking_provider_name"), + data.get("retrieval_model").get("reranking_model").get("reranking_model_name"), + ) # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator DatasetPermissionService.check_permission( diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index e4779f3bdf..6213fad173 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -29,7 +29,7 @@ from extensions.ext_database import db from fields.document_fields import document_fields, document_status_fields from libs.login import current_user from models.dataset import Dataset, Document, DocumentSegment -from services.dataset_service import DocumentService +from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig from services.file_service import FileService @@ -59,6 +59,7 @@ class DocumentAddByTextApi(DatasetApiResource): parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") args = parser.parse_args() + dataset_id = str(dataset_id) tenant_id = str(tenant_id) dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() @@ -74,6 +75,21 @@ class DocumentAddByTextApi(DatasetApiResource): if text is None or name is None: raise ValueError("Both 'text' and 'name' must be non-null values.") + if args.get("embedding_model_provider"): + DatasetService.check_embedding_model_setting( + tenant_id, args.get("embedding_model_provider"), args.get("embedding_model") + ) + if ( + args.get("retrieval_model") + and args.get("retrieval_model").get("reranking_model") + and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name") + ): + DatasetService.check_reranking_model_setting( + tenant_id, + args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"), + args.get("retrieval_model").get("reranking_model").get("reranking_model_name"), + ) + upload_file = FileService.upload_text(text=str(text), text_name=str(name)) data_source = { "type": "upload_file", @@ -124,6 +140,17 @@ class DocumentUpdateByTextApi(DatasetApiResource): if not dataset: raise ValueError("Dataset does not exist.") + if ( + args.get("retrieval_model") + and args.get("retrieval_model").get("reranking_model") + and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name") + ): + DatasetService.check_reranking_model_setting( + tenant_id, + args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"), + args.get("retrieval_model").get("reranking_model").get("reranking_model_name"), + ) + # indexing_technique is already set in dataset since this is an update args["indexing_technique"] = dataset.indexing_technique @@ -188,6 +215,21 @@ class DocumentAddByFileApi(DatasetApiResource): raise ValueError("indexing_technique is required.") args["indexing_technique"] = indexing_technique + if "embedding_model_provider" in args: + DatasetService.check_embedding_model_setting( + tenant_id, args["embedding_model_provider"], args["embedding_model"] + ) + if ( + "retrieval_model" in args + and args["retrieval_model"].get("reranking_model") + and args["retrieval_model"].get("reranking_model").get("reranking_provider_name") + ): + DatasetService.check_reranking_model_setting( + tenant_id, + args["retrieval_model"].get("reranking_model").get("reranking_provider_name"), + args["retrieval_model"].get("reranking_model").get("reranking_model_name"), + ) + # save file info file = request.files["file"] # check file diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index ef511857cf..e42b5ace75 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -278,6 +278,23 @@ class DatasetService: except ProviderTokenNotInitError as ex: raise ValueError(ex.description) + @staticmethod + def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str): + try: + model_manager = ModelManager() + model_manager.get_model_instance( + tenant_id=tenant_id, + provider=reranking_model_provider, + model_type=ModelType.RERANK, + model=reranking_model, + ) + except LLMBadRequestError: + raise ValueError( + "No Rerank Model available. Please configure a valid provider in the Settings -> Model Provider." + ) + except ProviderTokenNotInitError as ex: + raise ValueError(ex.description) + @staticmethod def update_dataset(dataset_id, data, user): """