| 
									
										
										
										
											2024-02-06 13:21:13 +08:00
										 |  |  | from flask import request | 
					
						
							|  |  |  | from flask_restful import marshal, reqparse | 
					
						
							| 
									
										
										
										
											2024-06-11 05:21:38 +02:00
										 |  |  | from werkzeug.exceptions import NotFound | 
					
						
							| 
									
										
										
										
											2024-02-06 13:21:13 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-27 16:06:32 +08:00
										 |  |  | import services.dataset_service | 
					
						
							|  |  |  | from controllers.service_api import api | 
					
						
							| 
									
										
										
										
											2024-06-14 03:25:38 +08:00
										 |  |  | from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError | 
					
						
							| 
									
										
										
										
											2023-09-27 16:06:32 +08:00
										 |  |  | from controllers.service_api.wraps import DatasetApiResource | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | from core.model_runtime.entities.model_entities import ModelType | 
					
						
							|  |  |  | from core.provider_manager import ProviderManager | 
					
						
							| 
									
										
										
										
											2023-09-27 16:06:32 +08:00
										 |  |  | from fields.dataset_fields import dataset_detail_fields | 
					
						
							| 
									
										
										
										
											2024-01-12 12:34:01 +08:00
										 |  |  | from libs.login import current_user | 
					
						
							| 
									
										
										
										
											2024-08-21 20:25:45 +08:00
										 |  |  | from models.dataset import Dataset, DatasetPermissionEnum | 
					
						
							| 
									
										
										
										
											2023-09-27 16:06:32 +08:00
										 |  |  | from services.dataset_service import DatasetService | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def _validate_name(name): | 
					
						
							|  |  |  |     if not name or len(name) < 1 or len(name) > 40: | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |         raise ValueError("Name must be between 1 to 40 characters.") | 
					
						
							| 
									
										
										
										
											2023-09-27 16:06:32 +08:00
										 |  |  |     return name | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-11 05:21:38 +02:00
										 |  |  | class DatasetListApi(DatasetApiResource): | 
					
						
							|  |  |  |     """Resource for datasets.""" | 
					
						
							| 
									
										
										
										
											2023-09-27 16:06:32 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def get(self, tenant_id): | 
					
						
							| 
									
										
										
										
											2024-06-11 05:21:38 +02:00
										 |  |  |         """Resource for getting datasets.""" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |         page = request.args.get("page", default=1, type=int) | 
					
						
							|  |  |  |         limit = request.args.get("limit", default=20, type=int) | 
					
						
							| 
									
										
										
										
											2024-09-30 15:38:43 +08:00
										 |  |  |         # provider = request.args.get("provider", default="vendor") | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |         search = request.args.get("keyword", default=None, type=str) | 
					
						
							|  |  |  |         tag_ids = request.args.getlist("tag_ids") | 
					
						
							| 
									
										
										
										
											2024-04-24 15:02:29 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-30 15:38:43 +08:00
										 |  |  |         datasets, total = DatasetService.get_datasets(page, limit, tenant_id, current_user, search, tag_ids) | 
					
						
							| 
									
										
										
										
											2023-09-27 16:06:32 +08:00
										 |  |  |         # check embedding setting | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         provider_manager = ProviderManager() | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |         configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |         embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-27 16:06:32 +08:00
										 |  |  |         model_names = [] | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         for embedding_model in embedding_models: | 
					
						
							|  |  |  |             model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-27 16:06:32 +08:00
										 |  |  |         data = marshal(datasets, dataset_detail_fields) | 
					
						
							|  |  |  |         for item in data: | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |             if item["indexing_technique"] == "high_quality": | 
					
						
							| 
									
										
										
										
											2023-09-27 16:06:32 +08:00
										 |  |  |                 item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" | 
					
						
							|  |  |  |                 if item_model in model_names: | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |                     item["embedding_available"] = True | 
					
						
							| 
									
										
										
										
											2023-09-27 16:06:32 +08:00
										 |  |  |                 else: | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |                     item["embedding_available"] = False | 
					
						
							| 
									
										
										
										
											2023-09-27 16:06:32 +08:00
										 |  |  |             else: | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |                 item["embedding_available"] = True | 
					
						
							|  |  |  |         response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} | 
					
						
							| 
									
										
										
										
											2023-09-27 16:06:32 +08:00
										 |  |  |         return response, 200 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def post(self, tenant_id): | 
					
						
							| 
									
										
										
										
											2024-06-11 05:21:38 +02:00
										 |  |  |         """Resource for creating datasets.""" | 
					
						
							| 
									
										
										
										
											2023-09-27 16:06:32 +08:00
										 |  |  |         parser = reqparse.RequestParser() | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |         parser.add_argument( | 
					
						
							|  |  |  |             "name", | 
					
						
							|  |  |  |             nullable=False, | 
					
						
							|  |  |  |             required=True, | 
					
						
							|  |  |  |             help="type is required. Name must be between 1 to 40 characters.", | 
					
						
							|  |  |  |             type=_validate_name, | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-10-25 10:50:15 +08:00
										 |  |  |         parser.add_argument( | 
					
						
							|  |  |  |             "description", | 
					
						
							|  |  |  |             type=str, | 
					
						
							|  |  |  |             nullable=True, | 
					
						
							|  |  |  |             required=False, | 
					
						
							|  |  |  |             default="", | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |         parser.add_argument( | 
					
						
							|  |  |  |             "indexing_technique", | 
					
						
							|  |  |  |             type=str, | 
					
						
							|  |  |  |             location="json", | 
					
						
							|  |  |  |             choices=Dataset.INDEXING_TECHNIQUE_LIST, | 
					
						
							|  |  |  |             help="Invalid indexing technique.", | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         parser.add_argument( | 
					
						
							|  |  |  |             "permission", | 
					
						
							|  |  |  |             type=str, | 
					
						
							|  |  |  |             location="json", | 
					
						
							|  |  |  |             choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), | 
					
						
							|  |  |  |             help="Invalid permission.", | 
					
						
							|  |  |  |             required=False, | 
					
						
							|  |  |  |             nullable=False, | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-09-30 15:38:43 +08:00
										 |  |  |         parser.add_argument( | 
					
						
							|  |  |  |             "external_knowledge_api_id", | 
					
						
							|  |  |  |             type=str, | 
					
						
							|  |  |  |             nullable=True, | 
					
						
							|  |  |  |             required=False, | 
					
						
							|  |  |  |             default="_validate_name", | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         parser.add_argument( | 
					
						
							|  |  |  |             "provider", | 
					
						
							|  |  |  |             type=str, | 
					
						
							|  |  |  |             nullable=True, | 
					
						
							|  |  |  |             required=False, | 
					
						
							|  |  |  |             default="vendor", | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         parser.add_argument( | 
					
						
							|  |  |  |             "external_knowledge_id", | 
					
						
							|  |  |  |             type=str, | 
					
						
							|  |  |  |             nullable=True, | 
					
						
							|  |  |  |             required=False, | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2023-09-27 16:06:32 +08:00
										 |  |  |         args = parser.parse_args() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             dataset = DatasetService.create_empty_dataset( | 
					
						
							|  |  |  |                 tenant_id=tenant_id, | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |                 name=args["name"], | 
					
						
							| 
									
										
										
										
											2024-10-25 10:50:15 +08:00
										 |  |  |                 description=args["description"], | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |                 indexing_technique=args["indexing_technique"], | 
					
						
							| 
									
										
										
										
											2024-08-21 20:25:45 +08:00
										 |  |  |                 account=current_user, | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |                 permission=args["permission"], | 
					
						
							| 
									
										
										
										
											2024-09-30 15:38:43 +08:00
										 |  |  |                 provider=args["provider"], | 
					
						
							|  |  |  |                 external_knowledge_api_id=args["external_knowledge_api_id"], | 
					
						
							|  |  |  |                 external_knowledge_id=args["external_knowledge_id"], | 
					
						
							| 
									
										
										
										
											2023-09-27 16:06:32 +08:00
										 |  |  |             ) | 
					
						
							|  |  |  |         except services.errors.dataset.DatasetNameDuplicateError: | 
					
						
							|  |  |  |             raise DatasetNameDuplicateError() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return marshal(dataset, dataset_detail_fields), 200 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-11 05:21:38 +02:00
										 |  |  | class DatasetApi(DatasetApiResource): | 
					
						
							|  |  |  |     """Resource for dataset.""" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def delete(self, _, dataset_id): | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Deletes a dataset given its ID. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Args: | 
					
						
							|  |  |  |             dataset_id (UUID): The ID of the dataset to be deleted. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Returns: | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |             dict: A dictionary with a key 'result' and a value 'success' | 
					
						
							| 
									
										
										
										
											2024-06-11 05:21:38 +02:00
										 |  |  |                   if the dataset was successfully deleted. Omitted in HTTP response. | 
					
						
							|  |  |  |             int: HTTP status code 204 indicating that the operation was successful. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Raises: | 
					
						
							|  |  |  |             NotFound: If the dataset with the given ID does not exist. | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         dataset_id_str = str(dataset_id) | 
					
						
							| 
									
										
										
										
											2023-09-27 16:06:32 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-14 03:25:38 +08:00
										 |  |  |         try: | 
					
						
							|  |  |  |             if DatasetService.delete_dataset(dataset_id_str, current_user): | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |                 return {"result": "success"}, 204 | 
					
						
							| 
									
										
										
										
											2024-06-14 03:25:38 +08:00
										 |  |  |             else: | 
					
						
							|  |  |  |                 raise NotFound("Dataset not found.") | 
					
						
							|  |  |  |         except services.errors.dataset.DatasetInUseError: | 
					
						
							|  |  |  |             raise DatasetInUseError() | 
					
						
							| 
									
										
										
										
											2023-09-27 16:06:32 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  | 
 | 
					
						
# Route registration: the dataset collection (list/create) and a single dataset
# addressed by UUID. The <uuid:dataset_id> converter passes dataset_id as a
# UUID, which DatasetApi.delete converts to str.
api.add_resource(DatasetListApi, "/datasets")
api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")