from typing import Any, Literal, cast

from flask import request
from flask_restx import marshal
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import Forbidden, NotFound

import services
from controllers.common.schema import register_schema_models
from controllers.console.wraps import edit_permission_required
from controllers.service_api import service_api_ns
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
from controllers.service_api.wraps import (
    DatasetApiResource,
    cloud_edition_billing_rate_limit_check,
    validate_dataset_token,
)
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from fields.dataset_fields import dataset_detail_fields
from fields.tag_fields import build_dataset_tag_fields
from libs.login import current_user
from models.account import Account
from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
from services.tag_service import TagService


class DatasetCreatePayload(BaseModel):
    """Request body for creating a dataset (``POST /datasets``)."""

    name: str = Field(..., min_length=1, max_length=40)
    description: str = Field(default="", description="Dataset description (max 400 chars)", max_length=400)
    indexing_technique: Literal["high_quality", "economy"] | None = None
    permission: DatasetPermissionEnum | None = DatasetPermissionEnum.ONLY_ME
    external_knowledge_api_id: str | None = None
    provider: str = "vendor"
    external_knowledge_id: str | None = None
    retrieval_model: RetrievalModel | None = None
    embedding_model: str | None = None
    embedding_model_provider: str | None = None


class DatasetUpdatePayload(BaseModel):
    """Request body for partially updating a dataset; every field is optional."""

    name: str | None = Field(default=None, min_length=1, max_length=40)
    description: str | None = Field(default=None, description="Dataset description (max 400 chars)", max_length=400)
    indexing_technique: Literal["high_quality", "economy"] | None = None
    permission: DatasetPermissionEnum | None = None
    embedding_model: str | None = None
    embedding_model_provider: str | None = None
    retrieval_model: RetrievalModel | None = None
    partial_member_list: list[str] | None = None
    external_retrieval_model: dict[str, Any] | None = None
    external_knowledge_id: str | None = None
    external_knowledge_api_id: str | None = None


class TagNamePayload(BaseModel):
    """Shared base for payloads that carry a tag name (1-50 characters)."""

    name: str = Field(..., min_length=1, max_length=50)


class TagCreatePayload(TagNamePayload):
    """Request body for creating a knowledge tag."""

    pass


class TagUpdatePayload(TagNamePayload):
    """Request body for renaming an existing knowledge tag."""

    tag_id: str


class TagDeletePayload(BaseModel):
    """Request body for deleting a knowledge tag."""

    tag_id: str


class TagBindingPayload(BaseModel):
    """Request body for binding one or more tags to a target dataset."""

    tag_ids: list[str]
    target_id: str

    @field_validator("tag_ids")
    @classmethod
    def validate_tag_ids(cls, value: list[str]) -> list[str]:
        """Reject an empty ``tag_ids`` list."""
        if not value:
            raise ValueError("Tag IDs is required.")
        return value


class TagUnbindingPayload(BaseModel):
    """Request body for removing a single tag binding from a target dataset."""

    tag_id: str
    target_id: str


register_schema_models(
    service_api_ns,
    DatasetCreatePayload,
    DatasetUpdatePayload,
    TagCreatePayload,
    TagUpdatePayload,
    TagDeletePayload,
    TagBindingPayload,
    TagUnbindingPayload,
)


def _get_active_embedding_model_names(tenant_id: str) -> set[str]:
    """Return ``"<model>:<provider>"`` keys for the tenant's active text-embedding models.

    Shared by the list and detail endpoints so both compare datasets against the same key format.
    """
    configurations = ProviderManager().get_configurations(tenant_id=tenant_id)
    embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True)
    return {f"{embedding_model.model}:{embedding_model.provider.provider}" for embedding_model in embedding_models}


@service_api_ns.route("/datasets")
class DatasetListApi(DatasetApiResource):
    """Resource for datasets."""

    @service_api_ns.doc("list_datasets")
    @service_api_ns.doc(description="List all datasets")
    @service_api_ns.doc(
        responses={
            200: "Datasets retrieved successfully",
            401: "Unauthorized - invalid API token",
        }
    )
    def get(self, tenant_id):
        """List datasets, annotating each with ``embedding_available``.

        Query parameters: ``page`` (default 1), ``limit`` (default 20), ``keyword``,
        repeated ``tag_ids``, and ``include_all`` (enabled when equal to ``"true"``,
        case-insensitive).
        """
        page = request.args.get("page", default=1, type=int)
        limit = request.args.get("limit", default=20, type=int)
        search = request.args.get("keyword", default=None, type=str)
        tag_ids = request.args.getlist("tag_ids")
        include_all = request.args.get("include_all", default="false").lower() == "true"

        datasets, total = DatasetService.get_datasets(
            page, limit, tenant_id, current_user, search, tag_ids, include_all
        )

        # check embedding setting
        assert isinstance(current_user, Account)
        cid = current_user.current_tenant_id
        assert cid is not None
        model_names = _get_active_embedding_model_names(cid)

        data = marshal(datasets, dataset_detail_fields)
        for item in data:
            if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]:
                # Normalize through ModelProviderID so the key is comparable with the provider
                # IDs reported by ProviderManager (assumed canonical form — confirm in ModelProviderID).
                item["embedding_model_provider"] = str(ModelProviderID(item["embedding_model_provider"]))
                item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
                item["embedding_available"] = item_model in model_names
            else:
                item["embedding_available"] = True

        response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
        return response, 200

    @service_api_ns.expect(service_api_ns.models[DatasetCreatePayload.__name__])
    @service_api_ns.doc("create_dataset")
    @service_api_ns.doc(description="Create a new dataset")
    @service_api_ns.doc(
        responses={
            200: "Dataset created successfully",
            401: "Unauthorized - invalid API token",
            400: "Bad request - invalid parameters",
        }
    )
    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
    def post(self, tenant_id):
        """Create an empty dataset after validating the requested embedding/reranking models.

        Raises:
            DatasetNameDuplicateError: If a dataset with the same name already exists.
        """
        payload = DatasetCreatePayload.model_validate(service_api_ns.payload or {})

        embedding_model_provider = payload.embedding_model_provider
        embedding_model = payload.embedding_model
        if embedding_model_provider and embedding_model:
            DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model)

        retrieval_model = payload.retrieval_model
        if (
            retrieval_model
            and retrieval_model.reranking_model
            and retrieval_model.reranking_model.reranking_provider_name
            and retrieval_model.reranking_model.reranking_model_name
        ):
            DatasetService.check_reranking_model_setting(
                tenant_id,
                retrieval_model.reranking_model.reranking_provider_name,
                retrieval_model.reranking_model.reranking_model_name,
            )

        try:
            assert isinstance(current_user, Account)
            dataset = DatasetService.create_empty_dataset(
                tenant_id=tenant_id,
                name=payload.name,
                description=payload.description,
                indexing_technique=payload.indexing_technique,
                account=current_user,
                permission=str(payload.permission) if payload.permission else None,
                provider=payload.provider,
                external_knowledge_api_id=payload.external_knowledge_api_id,
                external_knowledge_id=payload.external_knowledge_id,
                embedding_model_provider=payload.embedding_model_provider,
                embedding_model_name=payload.embedding_model,
                retrieval_model=payload.retrieval_model,
            )
        except services.errors.dataset.DatasetNameDuplicateError:
            raise DatasetNameDuplicateError()

        return marshal(dataset, dataset_detail_fields), 200


@service_api_ns.route("/datasets/<uuid:dataset_id>")
class DatasetApi(DatasetApiResource):
    """Resource for dataset."""

    @service_api_ns.doc("get_dataset")
    @service_api_ns.doc(description="Get a specific dataset by ID")
    @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
    @service_api_ns.doc(
        responses={
            200: "Dataset retrieved successfully",
            401: "Unauthorized - invalid API token",
            403: "Forbidden - insufficient permissions",
            404: "Dataset not found",
        }
    )
    def get(self, _, dataset_id):
        """Return one dataset's details, including ``embedding_available``.

        Raises:
            NotFound: If the dataset does not exist.
            Forbidden: If the current user may not access the dataset.
        """
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))
        data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))

        # check embedding setting
        assert isinstance(current_user, Account)
        cid = current_user.current_tenant_id
        assert cid is not None
        model_names = _get_active_embedding_model_names(cid)

        if data.get("indexing_technique") == "high_quality":
            provider = data.get("embedding_model_provider")
            if provider:
                # Normalize exactly like DatasetListApi.get; otherwise a provider stored in a
                # different form than ProviderManager reports could never match.
                data["embedding_model_provider"] = str(ModelProviderID(provider))
            item_model = f"{data.get('embedding_model')}:{data.get('embedding_model_provider')}"
            data["embedding_available"] = item_model in model_names
        else:
            data["embedding_available"] = True

            # force update search method to keyword_search if indexing_technique is economic
            retrieval_model_dict = data.get("retrieval_model_dict")
            if retrieval_model_dict:
                retrieval_model_dict["search_method"] = "keyword_search"

        if data.get("permission") == "partial_members":
            part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
            data.update({"partial_member_list": part_users_list})

        return data, 200

    @service_api_ns.expect(service_api_ns.models[DatasetUpdatePayload.__name__])
    @service_api_ns.doc("update_dataset")
    @service_api_ns.doc(description="Update an existing dataset")
    @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
    @service_api_ns.doc(
        responses={
            200: "Dataset updated successfully",
            401: "Unauthorized - invalid API token",
            403: "Forbidden - insufficient permissions",
            404: "Dataset not found",
        }
    )
    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
    def patch(self, _, dataset_id):
        """Partially update a dataset and synchronize its partial-member list.

        Only fields present in the request body are forwarded to ``DatasetService.update_dataset``.

        Raises:
            NotFound: If the dataset does not exist (before or after the update).
        """
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")

        payload_dict = service_api_ns.payload or {}
        payload = DatasetUpdatePayload.model_validate(payload_dict)
        # exclude_unset keeps this a partial update: absent fields are not overwritten.
        update_data = payload.model_dump(exclude_unset=True)
        if payload.permission is not None:
            update_data["permission"] = str(payload.permission)
        if payload.retrieval_model is not None:
            update_data["retrieval_model"] = payload.retrieval_model.model_dump()

        # check embedding model setting
        embedding_model_provider = payload.embedding_model_provider
        embedding_model = payload.embedding_model
        if payload.indexing_technique == "high_quality" or embedding_model_provider:
            if embedding_model_provider and embedding_model:
                DatasetService.check_embedding_model_setting(
                    dataset.tenant_id, embedding_model_provider, embedding_model
                )

        retrieval_model = payload.retrieval_model
        if (
            retrieval_model
            and retrieval_model.reranking_model
            and retrieval_model.reranking_model.reranking_provider_name
            and retrieval_model.reranking_model.reranking_model_name
        ):
            DatasetService.check_reranking_model_setting(
                dataset.tenant_id,
                retrieval_model.reranking_model.reranking_provider_name,
                retrieval_model.reranking_model.reranking_model_name,
            )

        # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
        DatasetPermissionService.check_permission(
            current_user,
            dataset,
            str(payload.permission) if payload.permission else None,
            payload.partial_member_list,
        )

        dataset = DatasetService.update_dataset(dataset_id_str, update_data, current_user)

        if dataset is None:
            raise NotFound("Dataset not found.")

        result_data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields))
        assert isinstance(current_user, Account)
        tenant_id = current_user.current_tenant_id

        if payload.partial_member_list and payload.permission == DatasetPermissionEnum.PARTIAL_TEAM:
            DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id_str, payload.partial_member_list)
        # clear partial member list when permission is only_me or all_team_members
        elif payload.permission in {DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM}:
            DatasetPermissionService.clear_partial_member_list(dataset_id_str)

        partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
        result_data.update({"partial_member_list": partial_member_list})

        return result_data, 200

    @service_api_ns.doc("delete_dataset")
    @service_api_ns.doc(description="Delete a dataset")
    @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
    @service_api_ns.doc(
        responses={
            204: "Dataset deleted successfully",
            401: "Unauthorized - invalid API token",
            404: "Dataset not found",
            409: "Conflict - dataset is in use",
        }
    )
    @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
    def delete(self, _, dataset_id):
        """
        Deletes a dataset given its ID.

        Args:
            _: ignore
            dataset_id (UUID): The ID of the dataset to be deleted.

        Returns:
            tuple: An empty body and HTTP status code 204. (A bare ``204`` would be
            treated by Flask-RESTX as the response body with status 200.)

        Raises:
            NotFound: If the dataset with the given ID does not exist.
            DatasetInUseError: If the dataset is still in use.
        """

        dataset_id_str = str(dataset_id)

        try:
            if DatasetService.delete_dataset(dataset_id_str, current_user):
                DatasetPermissionService.clear_partial_member_list(dataset_id_str)
                return "", 204
            else:
                raise NotFound("Dataset not found.")
        except services.errors.dataset.DatasetInUseError:
            raise DatasetInUseError()


@service_api_ns.route("/datasets/<uuid:dataset_id>/documents/status/<string:action>")
class DocumentStatusApi(DatasetApiResource):
    """Resource for batch document status operations."""

    @service_api_ns.doc("update_document_status")
    @service_api_ns.doc(description="Batch update document status")
    @service_api_ns.doc(
        params={
            "dataset_id": "Dataset ID",
            "action": "Action to perform: 'enable', 'disable', 'archive', or 'un_archive'",
        }
    )
    @service_api_ns.doc(
        responses={
            200: "Document status updated successfully",
            401: "Unauthorized - invalid API token",
            403: "Forbidden - insufficient permissions",
            404: "Dataset not found",
            400: "Bad request - invalid action",
        }
    )
    def patch(self, tenant_id, dataset_id, action: Literal["enable", "disable", "archive", "un_archive"]):
        """
        Batch update document status.

        Args:
            tenant_id: tenant id
            dataset_id: dataset id
            action: action to perform (Literal["enable", "disable", "archive", "un_archive"])

        Returns:
            dict: A dictionary with a key 'result' and a value 'success'
            int: HTTP status code 200 indicating that the operation was successful.

        Raises:
            NotFound: If the dataset with the given ID does not exist.
            Forbidden: If the user does not have permission.
            InvalidActionError: If the action is invalid, cannot be performed, or the
                request body is not a JSON object.
        """
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")

        # Check user's permission
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        # Check dataset model setting
        DatasetService.check_dataset_model_setting(dataset)

        # Get document IDs from request body. A missing/null/non-object body previously
        # crashed with AttributeError (HTTP 500); report it as a client error instead.
        data = request.get_json()
        if not isinstance(data, dict):
            raise InvalidActionError("Request body must be a JSON object.")
        document_ids = data.get("document_ids", [])

        try:
            DocumentService.batch_update_document_status(dataset, document_ids, action, current_user)
        except services.errors.document.DocumentIndexingError as e:
            raise InvalidActionError(str(e))
        except ValueError as e:
            raise InvalidActionError(str(e))

        return {"result": "success"}, 200


@service_api_ns.route("/datasets/tags")
class DatasetTagsApi(DatasetApiResource):
    """Resource for listing, creating, renaming and deleting knowledge tags."""

    @service_api_ns.doc("list_dataset_tags")
    @service_api_ns.doc(description="Get all knowledge type tags")
    @service_api_ns.doc(
        responses={
            200: "Tags retrieved successfully",
            401: "Unauthorized - invalid API token",
        }
    )
    @validate_dataset_token
    @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
    def get(self, _, dataset_id):
        """Get all knowledge type tags."""
        assert isinstance(current_user, Account)
        cid = current_user.current_tenant_id
        assert cid is not None
        tags = TagService.get_tags("knowledge", cid)

        return tags, 200

    @service_api_ns.expect(service_api_ns.models[TagCreatePayload.__name__])
    @service_api_ns.doc("create_dataset_tag")
    @service_api_ns.doc(description="Add a knowledge type tag")
    @service_api_ns.doc(
        responses={
            200: "Tag created successfully",
            401: "Unauthorized - invalid API token",
            403: "Forbidden - insufficient permissions",
        }
    )
    @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
    @validate_dataset_token
    def post(self, _, dataset_id):
        """Add a knowledge type tag.

        Raises:
            Forbidden: If the user has neither edit permission nor the dataset-editor role.
        """
        assert isinstance(current_user, Account)
        if not (current_user.has_edit_permission or current_user.is_dataset_editor):
            raise Forbidden()

        payload = TagCreatePayload.model_validate(service_api_ns.payload or {})
        tag = TagService.save_tags({"name": payload.name, "type": "knowledge"})

        # A freshly created tag has no bindings yet.
        response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0}
        return response, 200

    @service_api_ns.expect(service_api_ns.models[TagUpdatePayload.__name__])
    @service_api_ns.doc("update_dataset_tag")
    @service_api_ns.doc(description="Update a knowledge type tag")
    @service_api_ns.doc(
        responses={
            200: "Tag updated successfully",
            401: "Unauthorized - invalid API token",
            403: "Forbidden - insufficient permissions",
        }
    )
    @service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
    @validate_dataset_token
    def patch(self, _, dataset_id):
        """Rename a knowledge type tag and return it with its current binding count.

        Raises:
            Forbidden: If the user has neither edit permission nor the dataset-editor role.
        """
        assert isinstance(current_user, Account)
        if not (current_user.has_edit_permission or current_user.is_dataset_editor):
            raise Forbidden()

        payload = TagUpdatePayload.model_validate(service_api_ns.payload or {})
        params = {"name": payload.name, "type": "knowledge"}
        tag_id = payload.tag_id
        tag = TagService.update_tags(params, tag_id)

        binding_count = TagService.get_tag_binding_count(tag_id)

        response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count}

        return response, 200

    @service_api_ns.expect(service_api_ns.models[TagDeletePayload.__name__])
    @service_api_ns.doc("delete_dataset_tag")
    @service_api_ns.doc(description="Delete a knowledge type tag")
    @service_api_ns.doc(
        responses={
            204: "Tag deleted successfully",
            401: "Unauthorized - invalid API token",
            403: "Forbidden - insufficient permissions",
        }
    )
    @validate_dataset_token
    @edit_permission_required
    def delete(self, _, dataset_id):
        """Delete a knowledge type tag; responds 204 with an empty body."""
        payload = TagDeletePayload.model_validate(service_api_ns.payload or {})
        TagService.delete_tag(payload.tag_id)

        return "", 204


@service_api_ns.route("/datasets/tags/binding")
class DatasetTagBindingApi(DatasetApiResource):
    """Resource for binding knowledge tags to a dataset."""

    @service_api_ns.expect(service_api_ns.models[TagBindingPayload.__name__])
    @service_api_ns.doc("bind_dataset_tags")
    @service_api_ns.doc(description="Bind tags to a dataset")
    @service_api_ns.doc(
        responses={
            204: "Tags bound successfully",
            401: "Unauthorized - invalid API token",
            403: "Forbidden - insufficient permissions",
        }
    )
    @validate_dataset_token
    def post(self, _, dataset_id):
        """Bind the given tags to the target dataset; responds 204 with an empty body.

        Raises:
            Forbidden: If the user has neither edit permission nor the dataset-editor role.
        """
        # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
        assert isinstance(current_user, Account)
        if not (current_user.has_edit_permission or current_user.is_dataset_editor):
            raise Forbidden()

        payload = TagBindingPayload.model_validate(service_api_ns.payload or {})
        TagService.save_tag_binding({"tag_ids": payload.tag_ids, "target_id": payload.target_id, "type": "knowledge"})

        return "", 204


@service_api_ns.route("/datasets/tags/unbinding")
class DatasetTagUnbindingApi(DatasetApiResource):
    """Resource for removing a knowledge tag binding from a dataset."""

    @service_api_ns.expect(service_api_ns.models[TagUnbindingPayload.__name__])
    @service_api_ns.doc("unbind_dataset_tag")
    @service_api_ns.doc(description="Unbind a tag from a dataset")
    @service_api_ns.doc(
        responses={
            204: "Tag unbound successfully",
            401: "Unauthorized - invalid API token",
            403: "Forbidden - insufficient permissions",
        }
    )
    @validate_dataset_token
    def post(self, _, dataset_id):
        """Remove one tag binding from the target dataset; responds 204 with an empty body.

        Raises:
            Forbidden: If the user has neither edit permission nor the dataset-editor role.
        """
        # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
        assert isinstance(current_user, Account)
        if not (current_user.has_edit_permission or current_user.is_dataset_editor):
            raise Forbidden()

        payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {})
        TagService.delete_tag_binding({"tag_id": payload.tag_id, "target_id": payload.target_id, "type": "knowledge"})

        return "", 204


@service_api_ns.route("/datasets/<uuid:dataset_id>/tags")
class DatasetTagsBindingStatusApi(DatasetApiResource):
    """Resource reporting which knowledge tags are bound to a dataset."""

    @service_api_ns.doc("get_dataset_tags_binding_status")
    @service_api_ns.doc(description="Get tags bound to a specific dataset")
    @service_api_ns.doc(params={"dataset_id": "Dataset ID"})
    @service_api_ns.doc(
        responses={
            200: "Tags retrieved successfully",
            401: "Unauthorized - invalid API token",
        }
    )
    @validate_dataset_token
    def get(self, _, *args, **kwargs):
        """Get the knowledge type tags bound to the dataset identified by ``dataset_id``."""
        dataset_id = kwargs.get("dataset_id")
        assert isinstance(current_user, Account)
        assert current_user.current_tenant_id is not None
        tags = TagService.get_tags_by_target_id("knowledge", current_user.current_tenant_id, str(dataset_id))
        tags_list = [{"id": tag.id, "name": tag.name} for tag in tags]
        response = {"data": tags_list, "total": len(tags)}
        return response, 200