mirror of
				https://github.com/langgenius/dify.git
				synced 2025-10-31 02:42:59 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			750 lines
		
	
	
		
			28 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			750 lines
		
	
	
		
			28 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import flask_restful
 | |
| from flask import request
 | |
| from flask_login import current_user
 | |
| from flask_restful import Resource, marshal, marshal_with, reqparse
 | |
| from werkzeug.exceptions import Forbidden, NotFound
 | |
| 
 | |
| import services
 | |
| from configs import dify_config
 | |
| from controllers.console import api
 | |
| from controllers.console.apikey import api_key_fields, api_key_list
 | |
| from controllers.console.app.error import ProviderNotInitializeError
 | |
| from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
 | |
| from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
 | |
| from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 | |
| from core.indexing_runner import IndexingRunner
 | |
| from core.model_runtime.entities.model_entities import ModelType
 | |
| from core.provider_manager import ProviderManager
 | |
| from core.rag.datasource.vdb.vector_type import VectorType
 | |
| from core.rag.extractor.entity.extract_setting import ExtractSetting
 | |
| from core.rag.retrieval.retrieval_methods import RetrievalMethod
 | |
| from extensions.ext_database import db
 | |
| from fields.app_fields import related_app_list
 | |
| from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
 | |
| from fields.document_fields import document_status_fields
 | |
| from libs.login import login_required
 | |
| from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
 | |
| from models.dataset import DatasetPermissionEnum
 | |
| from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
 | |
| 
 | |
| 
 | |
| def _validate_name(name):
 | |
|     if not name or len(name) < 1 or len(name) > 40:
 | |
|         raise ValueError("Name must be between 1 to 40 characters.")
 | |
|     return name
 | |
| 
 | |
| 
 | |
| def _validate_description_length(description):
 | |
|     if len(description) > 400:
 | |
|         raise ValueError("Description cannot exceed 400 characters.")
 | |
|     return description
 | |
| 
 | |
| 
 | |
class DatasetListApi(Resource):
    """Console endpoints for listing datasets and creating an empty dataset."""

    @setup_required
    @login_required
    @account_initialization_required
    @enterprise_license_required
    def get(self):
        """Return one page of datasets for the current tenant.

        Reads the ``page``, ``limit``, ``ids``, ``keyword`` and ``tag_ids`` query
        parameters. When ``ids`` is given the datasets are fetched by id;
        otherwise the paginated/search listing is used. Each returned item is
        annotated with ``embedding_available`` and ``partial_member_list``.
        """
        page = request.args.get("page", default=1, type=int)
        limit = request.args.get("limit", default=20, type=int)
        ids = request.args.getlist("ids")
        search = request.args.get("keyword", default=None, type=str)
        tag_ids = request.args.getlist("tag_ids")
        tenant_id = current_user.current_tenant_id

        if ids:
            datasets, total = DatasetService.get_datasets_by_ids(ids, tenant_id)
        else:
            datasets, total = DatasetService.get_datasets(page, limit, tenant_id, current_user, search, tag_ids)

        # check embedding setting
        configurations = ProviderManager().get_configurations(tenant_id=tenant_id)
        embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
        # "<model>:<provider>" keys of the active embedding models; a set keeps
        # the per-item membership test below constant-time.
        available_models = {f"{model.model}:{model.provider.provider}" for model in embedding_models}

        data = marshal(datasets, dataset_detail_fields)
        for item in data:
            if item["indexing_technique"] == "high_quality":
                item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
                item["embedding_available"] = item_model in available_models
            else:
                # Only "high_quality" datasets are checked against the embedding models.
                item["embedding_available"] = True

            if item.get("permission") == "partial_members":
                item["partial_member_list"] = DatasetPermissionService.get_dataset_partial_member_list(item["id"])
            else:
                item["partial_member_list"] = []

        response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
        return response, 200

    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        """Create an empty dataset owned by the current user.

        Raises:
            Forbidden: if ``current_user.is_dataset_editor`` is false.
            DatasetNameDuplicateError: if the service reports a duplicate name.
        """
        parser = reqparse.RequestParser()
        parser.add_argument(
            "name",
            nullable=False,
            required=True,
            help="type is required. Name must be between 1 to 40 characters.",
            type=_validate_name,
        )
        # Apply the same 400-character limit that the PATCH endpoint enforces,
        # so an over-long description cannot be stored at creation time.
        parser.add_argument(
            "description",
            type=_validate_description_length,
            nullable=True,
            required=False,
            default="",
        )
        parser.add_argument(
            "indexing_technique",
            type=str,
            location="json",
            choices=Dataset.INDEXING_TECHNIQUE_LIST,
            nullable=True,
            help="Invalid indexing technique.",
        )
        parser.add_argument(
            "external_knowledge_api_id",
            type=str,
            nullable=True,
            required=False,
        )
        parser.add_argument(
            "provider",
            type=str,
            nullable=True,
            choices=Dataset.PROVIDER_LIST,
            required=False,
            default="vendor",
        )
        parser.add_argument(
            "external_knowledge_id",
            type=str,
            nullable=True,
            required=False,
        )
        args = parser.parse_args()

        # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
        if not current_user.is_dataset_editor:
            raise Forbidden()

        try:
            dataset = DatasetService.create_empty_dataset(
                tenant_id=current_user.current_tenant_id,
                name=args["name"],
                description=args["description"],
                indexing_technique=args["indexing_technique"],
                account=current_user,
                permission=DatasetPermissionEnum.ONLY_ME,
                provider=args["provider"],
                external_knowledge_api_id=args["external_knowledge_api_id"],
                external_knowledge_id=args["external_knowledge_id"],
            )
        except services.errors.dataset.DatasetNameDuplicateError:
            raise DatasetNameDuplicateError()

        return marshal(dataset, dataset_detail_fields), 201
 | |
| 
 | |
| 
 | |
class DatasetApi(Resource):
    """Console endpoints to read, update and delete a single dataset."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        """Return the dataset detail, annotated with ``embedding_available``.

        Raises:
            NotFound: if the dataset does not exist.
            Forbidden: if the permission check raises ``NoPermissionError``.
        """
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        data = marshal(dataset, dataset_detail_fields)
        # Looked up once; the member list is only attached for "partial_members".
        if data.get("permission") == "partial_members":
            part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
            data.update({"partial_member_list": part_users_list})

        # check embedding setting
        provider_manager = ProviderManager()
        configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id)
        embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
        available_models = {f"{model.model}:{model.provider.provider}" for model in embedding_models}

        if data["indexing_technique"] == "high_quality":
            item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
            data["embedding_available"] = item_model in available_models
        else:
            data["embedding_available"] = True

        return data, 200

    @setup_required
    @login_required
    @account_initialization_required
    def patch(self, dataset_id):
        """Update dataset settings and, when requested, its partial member list.

        Raises:
            NotFound: if the dataset does not exist before or after the update.
        """
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")

        parser = reqparse.RequestParser()
        parser.add_argument(
            "name",
            nullable=False,
            help="type is required. Name must be between 1 to 40 characters.",
            type=_validate_name,
        )
        parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length)
        parser.add_argument(
            "indexing_technique",
            type=str,
            location="json",
            choices=Dataset.INDEXING_TECHNIQUE_LIST,
            nullable=True,
            help="Invalid indexing technique.",
        )
        parser.add_argument(
            "permission",
            type=str,
            location="json",
            choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
            help="Invalid permission.",
        )
        parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.")
        parser.add_argument(
            "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider."
        )
        parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.")
        parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.")

        parser.add_argument(
            "external_retrieval_model",
            type=dict,
            required=False,
            nullable=True,
            location="json",
            help="Invalid external retrieval model.",
        )

        parser.add_argument(
            "external_knowledge_id",
            type=str,
            required=False,
            nullable=True,
            location="json",
            help="Invalid external knowledge id.",
        )

        parser.add_argument(
            "external_knowledge_api_id",
            type=str,
            required=False,
            nullable=True,
            location="json",
            help="Invalid external knowledge api id.",
        )
        args = parser.parse_args()
        # The raw JSON body drives the permission/embedding checks below, while
        # the parsed ``args`` are what is handed to the update service.
        data = request.get_json()

        # check embedding model setting
        if data.get("indexing_technique") == "high_quality":
            DatasetService.check_embedding_model_setting(
                dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
            )

        # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
        DatasetPermissionService.check_permission(
            current_user, dataset, data.get("permission"), data.get("partial_member_list")
        )

        dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)

        if dataset is None:
            raise NotFound("Dataset not found.")

        result_data = marshal(dataset, dataset_detail_fields)
        tenant_id = current_user.current_tenant_id

        if data.get("partial_member_list") and data.get("permission") == "partial_members":
            DatasetPermissionService.update_partial_member_list(
                tenant_id, dataset_id_str, data.get("partial_member_list")
            )
        # clear partial member list when permission is only_me or all_team_members
        elif (
            data.get("permission") == DatasetPermissionEnum.ONLY_ME
            or data.get("permission") == DatasetPermissionEnum.ALL_TEAM
        ):
            DatasetPermissionService.clear_partial_member_list(dataset_id_str)

        partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
        result_data.update({"partial_member_list": partial_member_list})

        return result_data, 200

    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, dataset_id):
        """Delete the dataset and clear its partial member list.

        Raises:
            Forbidden: if the current user is not an editor, or is a dataset operator.
            NotFound: if the service reports that nothing was deleted.
            DatasetInUseError: if the service reports the dataset is in use.
        """
        dataset_id_str = str(dataset_id)

        # The role of the current user in the ta table must be admin, owner, or editor
        if not current_user.is_editor or current_user.is_dataset_operator:
            raise Forbidden()

        try:
            if DatasetService.delete_dataset(dataset_id_str, current_user):
                DatasetPermissionService.clear_partial_member_list(dataset_id_str)
                return {"result": "success"}, 204
            else:
                raise NotFound("Dataset not found.")
        except services.errors.dataset.DatasetInUseError:
            raise DatasetInUseError()
 | |
| 
 | |
| 
 | |
class DatasetUseCheckApi(Resource):
    """Console endpoint reporting the result of the dataset use check."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        """Return ``{"is_using": <result of DatasetService.dataset_use_check>}``."""
        in_use = DatasetService.dataset_use_check(str(dataset_id))
        return {"is_using": in_use}, 200
 | |
| 
 | |
| 
 | |
class DatasetQueryApi(Resource):
    """Console endpoint listing the recorded queries of a dataset, paginated."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        """Return one page of dataset queries.

        Raises:
            NotFound: if the dataset does not exist.
            Forbidden: if the permission check raises ``NoPermissionError``.
        """
        dataset = DatasetService.get_dataset(str(dataset_id))
        if dataset is None:
            raise NotFound("Dataset not found.")

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        query_args = request.args
        page = query_args.get("page", default=1, type=int)
        limit = query_args.get("limit", default=20, type=int)

        dataset_queries, total = DatasetService.get_dataset_queries(dataset_id=dataset.id, page=page, per_page=limit)

        return {
            "data": marshal(dataset_queries, dataset_query_detail_fields),
            "has_more": len(dataset_queries) == limit,
            "limit": limit,
            "total": total,
            "page": page,
        }, 200
 | |
| 
 | |
| 
 | |
class DatasetIndexingEstimateApi(Resource):
    """Console endpoint that runs an indexing estimate for the given data sources."""

    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        """Build extract settings from ``info_list`` and run the indexing estimate.

        Supported ``info_list["data_source_type"]`` values are ``upload_file``,
        ``notion_import`` and ``website_crawl``.

        Raises:
            NotFound: if none of the requested upload files exist for the tenant.
            ValueError: if the data source type is not one of the supported values.
            ProviderNotInitializeError: on ``LLMBadRequestError`` or
                ``ProviderTokenNotInitError`` from the indexing runner.
            IndexingEstimateError: on any other exception from the indexing runner.
        """
        parser = reqparse.RequestParser()
        parser.add_argument("info_list", type=dict, required=True, nullable=True, location="json")
        parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json")
        parser.add_argument(
            "indexing_technique",
            type=str,
            required=True,
            choices=Dataset.INDEXING_TECHNIQUE_LIST,
            nullable=True,
            location="json",
        )
        parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
        parser.add_argument("dataset_id", type=str, required=False, nullable=False, location="json")
        parser.add_argument(
            "doc_language", type=str, default="English", required=False, nullable=False, location="json"
        )
        args = parser.parse_args()
        # validate args
        DocumentService.estimate_args_validate(args)

        info_list = args["info_list"]
        data_source_type = info_list["data_source_type"]
        tenant_id = current_user.current_tenant_id
        extract_settings = []

        if data_source_type == "upload_file":
            file_ids = info_list["file_info_list"]["file_ids"]
            file_details = (
                db.session.query(UploadFile)
                .filter(UploadFile.tenant_id == tenant_id, UploadFile.id.in_(file_ids))
                .all()
            )

            # Query.all() returns a list (never None), so "nothing matched"
            # must be detected with an emptiness check.
            if not file_details:
                raise NotFound("File not found.")

            extract_settings.extend(
                ExtractSetting(datasource_type="upload_file", upload_file=file_detail, document_model=args["doc_form"])
                for file_detail in file_details
            )
        elif data_source_type == "notion_import":
            for notion_info in info_list["notion_info_list"]:
                workspace_id = notion_info["workspace_id"]
                for page in notion_info["pages"]:
                    extract_settings.append(
                        ExtractSetting(
                            datasource_type="notion_import",
                            notion_info={
                                "notion_workspace_id": workspace_id,
                                "notion_obj_id": page["page_id"],
                                "notion_page_type": page["type"],
                                "tenant_id": tenant_id,
                            },
                            document_model=args["doc_form"],
                        )
                    )
        elif data_source_type == "website_crawl":
            website_info_list = info_list["website_info_list"]
            for url in website_info_list["urls"]:
                extract_settings.append(
                    ExtractSetting(
                        datasource_type="website_crawl",
                        website_info={
                            "provider": website_info_list["provider"],
                            "job_id": website_info_list["job_id"],
                            "url": url,
                            "tenant_id": tenant_id,
                            "mode": "crawl",
                            "only_main_content": website_info_list["only_main_content"],
                        },
                        document_model=args["doc_form"],
                    )
                )
        else:
            raise ValueError("Data source type not support")

        indexing_runner = IndexingRunner()
        try:
            response = indexing_runner.indexing_estimate(
                tenant_id,
                extract_settings,
                args["process_rule"],
                args["doc_form"],
                args["doc_language"],
                args["dataset_id"],
                args["indexing_technique"],
            )
        except LLMBadRequestError:
            raise ProviderNotInitializeError(
                "No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider."
            )
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        except Exception as e:
            raise IndexingEstimateError(str(e))

        return response, 200
 | |
| 
 | |
| 
 | |
class DatasetRelatedAppListApi(Resource):
    """Console endpoint listing the apps joined to a dataset."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(related_app_list)
    def get(self, dataset_id):
        """Return the related apps and their count.

        Raises:
            NotFound: if the dataset does not exist.
            Forbidden: if the permission check raises ``NoPermissionError``.
        """
        dataset = DatasetService.get_dataset(str(dataset_id))
        if dataset is None:
            raise NotFound("Dataset not found.")

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        joins = DatasetService.get_related_apps(dataset.id)
        # Keep only joins whose ``app`` is truthy; each ``app`` is read once.
        related_apps = list(filter(None, (join.app for join in joins)))

        return {"data": related_apps, "total": len(related_apps)}, 200
 | |
| 
 | |
| 
 | |
class DatasetIndexingStatusApi(Resource):
    """Console endpoint reporting per-document segment progress for a dataset."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        """Return ``{"data": [...]}`` with one marshalled status entry per document.

        For every document of the dataset in the current tenant, two segment
        counts (completed and total, both excluding ``re_segment`` status) are
        attached to the document before marshalling.
        """
        dataset_id = str(dataset_id)
        tenant_documents = (
            db.session.query(Document)
            .filter(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id)
            .all()
        )

        documents_status = []
        for document in tenant_documents:
            document_key = str(document.id)
            # Count of segments that already have a completion timestamp.
            document.completed_segments = DocumentSegment.query.filter(
                DocumentSegment.completed_at.isnot(None),
                DocumentSegment.document_id == document_key,
                DocumentSegment.status != "re_segment",
            ).count()
            # Count of all segments of the document.
            document.total_segments = DocumentSegment.query.filter(
                DocumentSegment.document_id == document_key,
                DocumentSegment.status != "re_segment",
            ).count()
            documents_status.append(marshal(document, document_status_fields))

        return {"data": documents_status}
 | |
| 
 | |
| 
 | |
class DatasetApiKeyApi(Resource):
    """Console endpoints to list and create dataset API keys for the current tenant."""

    # Upper bound on keys of this resource type per tenant (enforced in ``post``).
    max_keys = 10
    # Prefix passed to ``ApiToken.generate_api_key``.
    token_prefix = "dataset-"
    # Value matched against / stored in ``ApiToken.type``.
    resource_type = "dataset"

    def _tenant_tokens(self):
        """Build the query for this resource type's tokens in the current tenant."""
        return db.session.query(ApiToken).filter(
            ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id
        )

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(api_key_list)
    def get(self):
        """Return all API tokens of this resource type for the current tenant."""
        return {"items": self._tenant_tokens().all()}

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(api_key_fields)
    def post(self):
        """Create and persist a new API token.

        Raises:
            Forbidden: if ``current_user.is_admin_or_owner`` is false.

        Aborts with HTTP 400 when the tenant already has ``max_keys`` tokens.
        """
        # The role of the current user in the ta table must be admin or owner
        if not current_user.is_admin_or_owner:
            raise Forbidden()

        existing_count = self._tenant_tokens().count()
        if existing_count >= self.max_keys:
            flask_restful.abort(
                400,
                message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
                code="max_keys_exceeded",
            )

        generated_key = ApiToken.generate_api_key(self.token_prefix, 24)
        api_token = ApiToken()
        api_token.tenant_id = current_user.current_tenant_id
        api_token.token = generated_key
        api_token.type = self.resource_type

        session = db.session
        session.add(api_token)
        session.commit()
        return api_token, 200
 | |
| 
 | |
| 
 | |
class DatasetApiDeleteApi(Resource):
    """Console endpoint to delete one dataset API key of the current tenant."""

    # Value matched against ``ApiToken.type`` when looking up the key.
    resource_type = "dataset"

    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, api_key_id):
        """Delete the API key with the given id.

        Raises:
            Forbidden: if ``current_user.is_admin_or_owner`` is false.

        Aborts with HTTP 404 when no matching key exists for the tenant.
        """
        key_id = str(api_key_id)

        # The role of the current user in the ta table must be admin or owner
        if not current_user.is_admin_or_owner:
            raise Forbidden()

        # Confirm the key belongs to this tenant and resource type before deleting.
        owned_key_query = db.session.query(ApiToken).filter(
            ApiToken.tenant_id == current_user.current_tenant_id,
            ApiToken.type == self.resource_type,
            ApiToken.id == key_id,
        )
        if owned_key_query.first() is None:
            flask_restful.abort(404, message="API key not found")

        db.session.query(ApiToken).filter(ApiToken.id == key_id).delete()
        db.session.commit()

        return {"result": "success"}, 204
 | |
| 
 | |
| 
 | |
class DatasetApiBaseUrlApi(Resource):
    """Console endpoint exposing the base URL of the service API."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        """Return the configured service API URL (or the request host) plus ``/v1``."""
        # Fall back to the request's host URL, without its trailing slash,
        # when no service API URL is configured.
        base_url = dify_config.SERVICE_API_URL or request.host_url.rstrip("/")
        return {"api_base_url": base_url + "/v1"}
 | |
| 
 | |
| 
 | |
| class DatasetRetrievalSettingApi(Resource):
 | |
|     @setup_required
 | |
|     @login_required
 | |
|     @account_initialization_required
 | |
|     def get(self):
 | |
|         vector_type = dify_config.VECTOR_STORE
 | |
|         match vector_type:
 | |
|             case (
 | |
|                 VectorType.MILVUS
 | |
|                 | VectorType.RELYT
 | |
|                 | VectorType.PGVECTOR
 | |
|                 | VectorType.TIDB_VECTOR
 | |
|                 | VectorType.CHROMA
 | |
|                 | VectorType.TENCENT
 | |
|                 | VectorType.PGVECTO_RS
 | |
|                 | VectorType.BAIDU
 | |
|                 | VectorType.VIKINGDB
 | |
|                 | VectorType.UPSTASH
 | |
|                 | VectorType.OCEANBASE
 | |
|             ):
 | |
|                 return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
 | |
|             case (
 | |
|                 VectorType.QDRANT
 | |
|                 | VectorType.WEAVIATE
 | |
|                 | VectorType.OPENSEARCH
 | |
|                 | VectorType.ANALYTICDB
 | |
|                 | VectorType.MYSCALE
 | |
|                 | VectorType.ORACLE
 | |
|                 | VectorType.ELASTICSEARCH
 | |
|                 | VectorType.PGVECTOR
 | |
|                 | VectorType.TIDB_ON_QDRANT
 | |
|                 | VectorType.LINDORM
 | |
|                 | VectorType.COUCHBASE
 | |
|             ):
 | |
|                 return {
 | |
|                     "retrieval_method": [
 | |
|                         RetrievalMethod.SEMANTIC_SEARCH.value,
 | |
|                         RetrievalMethod.FULL_TEXT_SEARCH.value,
 | |
|                         RetrievalMethod.HYBRID_SEARCH.value,
 | |
|                     ]
 | |
|                 }
 | |
|             case _:
 | |
|                 raise ValueError(f"Unsupported vector db type {vector_type}.")
 | |
| 
 | |
| 
 | |
| class DatasetRetrievalSettingMockApi(Resource):
 | |
|     @setup_required
 | |
|     @login_required
 | |
|     @account_initialization_required
 | |
|     def get(self, vector_type):
 | |
|         match vector_type:
 | |
|             case (
 | |
|                 VectorType.MILVUS
 | |
|                 | VectorType.RELYT
 | |
|                 | VectorType.TIDB_VECTOR
 | |
|                 | VectorType.CHROMA
 | |
|                 | VectorType.TENCENT
 | |
|                 | VectorType.PGVECTO_RS
 | |
|                 | VectorType.BAIDU
 | |
|                 | VectorType.VIKINGDB
 | |
|                 | VectorType.UPSTASH
 | |
|                 | VectorType.OCEANBASE
 | |
|             ):
 | |
|                 return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]}
 | |
|             case (
 | |
|                 VectorType.QDRANT
 | |
|                 | VectorType.WEAVIATE
 | |
|                 | VectorType.OPENSEARCH
 | |
|                 | VectorType.ANALYTICDB
 | |
|                 | VectorType.MYSCALE
 | |
|                 | VectorType.ORACLE
 | |
|                 | VectorType.ELASTICSEARCH
 | |
|                 | VectorType.COUCHBASE
 | |
|                 | VectorType.PGVECTOR
 | |
|                 | VectorType.LINDORM
 | |
|             ):
 | |
|                 return {
 | |
|                     "retrieval_method": [
 | |
|                         RetrievalMethod.SEMANTIC_SEARCH.value,
 | |
|                         RetrievalMethod.FULL_TEXT_SEARCH.value,
 | |
|                         RetrievalMethod.HYBRID_SEARCH.value,
 | |
|                     ]
 | |
|                 }
 | |
|             case _:
 | |
|                 raise ValueError(f"Unsupported vector db type {vector_type}.")
 | |
| 
 | |
| 
 | |
class DatasetErrorDocs(Resource):
    """Console endpoint listing the error documents of a dataset."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        """Return the marshalled error documents and their count.

        Raises:
            NotFound: if the dataset does not exist.
        """
        dataset_id_str = str(dataset_id)
        if DatasetService.get_dataset(dataset_id_str) is None:
            raise NotFound("Dataset not found.")

        error_documents = DocumentService.get_error_documents_by_dataset_id(dataset_id_str)
        payload = [marshal(document, document_status_fields) for document in error_documents]

        return {"data": payload, "total": len(error_documents)}, 200
 | |
| 
 | |
| 
 | |
class DatasetPermissionUserListApi(Resource):
    """Console endpoint returning the partial member list of a dataset."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        """Return ``{"data": <partial member list>}`` for the dataset.

        Raises:
            NotFound: if the dataset does not exist.
            Forbidden: if the permission check raises ``NoPermissionError``.
        """
        dataset_key = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_key)
        if dataset is None:
            raise NotFound("Dataset not found.")

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        members = DatasetPermissionService.get_dataset_partial_member_list(dataset_key)
        return {"data": members}, 200
 | |
| 
 | |
| 
 | |
# Route table for the dataset console endpoints: (resource class, URL rule).
# Entries are registered on ``api`` in the order listed.
_DATASET_ROUTES = (
    (DatasetListApi, "/datasets"),
    (DatasetApi, "/datasets/<uuid:dataset_id>"),
    (DatasetUseCheckApi, "/datasets/<uuid:dataset_id>/use-check"),
    (DatasetQueryApi, "/datasets/<uuid:dataset_id>/queries"),
    (DatasetErrorDocs, "/datasets/<uuid:dataset_id>/error-docs"),
    (DatasetIndexingEstimateApi, "/datasets/indexing-estimate"),
    (DatasetRelatedAppListApi, "/datasets/<uuid:dataset_id>/related-apps"),
    (DatasetIndexingStatusApi, "/datasets/<uuid:dataset_id>/indexing-status"),
    (DatasetApiKeyApi, "/datasets/api-keys"),
    (DatasetApiDeleteApi, "/datasets/api-keys/<uuid:api_key_id>"),
    (DatasetApiBaseUrlApi, "/datasets/api-base-info"),
    (DatasetRetrievalSettingApi, "/datasets/retrieval-setting"),
    (DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/<string:vector_type>"),
    (DatasetPermissionUserListApi, "/datasets/<uuid:dataset_id>/permission-part-users"),
)

for _resource_cls, _url_rule in _DATASET_ROUTES:
    api.add_resource(_resource_cls, _url_rule)
 | 
