| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | import logging | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-06 13:21:13 +08:00
										 |  |  | from flask_login import current_user | 
					
						
							|  |  |  | from flask_restful import Resource, marshal, reqparse | 
					
						
							|  |  |  | from werkzeug.exceptions import Forbidden, InternalServerError, NotFound | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | import services | 
					
						
							|  |  |  | from controllers.console import api | 
					
						
							| 
									
										
										
										
											2024-02-06 13:21:13 +08:00
										 |  |  | from controllers.console.app.error import ( | 
					
						
							|  |  |  |     CompletionRequestError, | 
					
						
							|  |  |  |     ProviderModelCurrentlyNotSupportError, | 
					
						
							|  |  |  |     ProviderNotInitializeError, | 
					
						
							|  |  |  |     ProviderQuotaExceededError, | 
					
						
							|  |  |  | ) | 
					
						
							| 
									
										
										
										
											2024-04-17 17:40:40 +08:00
										 |  |  | from controllers.console.datasets.error import DatasetNotInitializedError | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | from controllers.console.setup import setup_required | 
					
						
							|  |  |  | from controllers.console.wraps import account_initialization_required | 
					
						
							| 
									
										
										
										
											2024-02-06 13:21:13 +08:00
										 |  |  | from core.errors.error import ( | 
					
						
							|  |  |  |     LLMBadRequestError, | 
					
						
							|  |  |  |     ModelCurrentlyNotSupportError, | 
					
						
							|  |  |  |     ProviderTokenNotInitError, | 
					
						
							|  |  |  |     QuotaExceededError, | 
					
						
							|  |  |  | ) | 
					
						
							| 
									
										
										
										
											2024-01-12 12:34:01 +08:00
										 |  |  | from core.model_runtime.errors.invoke import InvokeError | 
					
						
							| 
									
										
										
										
											2023-09-27 16:06:32 +08:00
										 |  |  | from fields.hit_testing_fields import hit_testing_record_fields | 
					
						
							| 
									
										
										
										
											2024-01-12 12:34:01 +08:00
										 |  |  | from libs.login import login_required | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | from services.dataset_service import DatasetService | 
					
						
							|  |  |  | from services.hit_testing_service import HitTestingService | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class HitTestingApi(Resource): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @setup_required | 
					
						
							|  |  |  |     @login_required | 
					
						
							|  |  |  |     @account_initialization_required | 
					
						
							|  |  |  |     def post(self, dataset_id): | 
					
						
							|  |  |  |         dataset_id_str = str(dataset_id) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         dataset = DatasetService.get_dataset(dataset_id_str) | 
					
						
							|  |  |  |         if dataset is None: | 
					
						
							|  |  |  |             raise NotFound("Dataset not found.") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             DatasetService.check_dataset_permission(dataset, current_user) | 
					
						
							|  |  |  |         except services.errors.account.NoPermissionError as e: | 
					
						
							|  |  |  |             raise Forbidden(str(e)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         parser = reqparse.RequestParser() | 
					
						
							|  |  |  |         parser.add_argument('query', type=str, location='json') | 
					
						
							| 
									
										
										
										
											2023-11-17 22:13:37 +08:00
										 |  |  |         parser.add_argument('retrieval_model', type=dict, required=False, location='json') | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         args = parser.parse_args() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-17 22:13:37 +08:00
										 |  |  |         HitTestingService.hit_testing_args_check(args) | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             response = HitTestingService.retrieve( | 
					
						
							|  |  |  |                 dataset=dataset, | 
					
						
							| 
									
										
										
										
											2023-11-17 22:13:37 +08:00
										 |  |  |                 query=args['query'], | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |                 account=current_user, | 
					
						
							| 
									
										
										
										
											2023-11-17 22:13:37 +08:00
										 |  |  |                 retrieval_model=args['retrieval_model'], | 
					
						
							|  |  |  |                 limit=10 | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)} | 
					
						
							|  |  |  |         except services.errors.index.IndexNotInitializedError: | 
					
						
							|  |  |  |             raise DatasetNotInitializedError() | 
					
						
							| 
									
										
										
										
											2023-07-17 00:14:19 +08:00
										 |  |  |         except ProviderTokenNotInitError as ex: | 
					
						
							|  |  |  |             raise ProviderNotInitializeError(ex.description) | 
					
						
							| 
									
										
										
										
											2023-05-22 17:39:28 +08:00
										 |  |  |         except QuotaExceededError: | 
					
						
							|  |  |  |             raise ProviderQuotaExceededError() | 
					
						
							|  |  |  |         except ModelCurrentlyNotSupportError: | 
					
						
							|  |  |  |             raise ProviderModelCurrentlyNotSupportError() | 
					
						
							| 
									
										
										
										
											2023-08-18 17:37:31 +08:00
										 |  |  |         except LLMBadRequestError: | 
					
						
							|  |  |  |             raise ProviderNotInitializeError( | 
					
						
							| 
									
										
										
										
											2024-02-08 14:11:10 +08:00
										 |  |  |                 "No Embedding Model or Reranking Model available. Please configure a valid provider " | 
					
						
							|  |  |  |                 "in the Settings -> Model Provider.") | 
					
						
							| 
									
										
										
										
											2024-01-03 08:57:39 +08:00
										 |  |  |         except InvokeError as e: | 
					
						
							|  |  |  |             raise CompletionRequestError(e.description) | 
					
						
							| 
									
										
										
										
											2023-08-12 00:57:00 +08:00
										 |  |  |         except ValueError as e: | 
					
						
							|  |  |  |             raise ValueError(str(e)) | 
					
						
							| 
									
										
										
										
											2023-05-15 08:51:32 +08:00
										 |  |  |         except Exception as e: | 
					
						
							|  |  |  |             logging.exception("Hit testing failed.") | 
					
						
							|  |  |  |             raise InternalServerError(str(e)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | api.add_resource(HitTestingApi, '/datasets/<uuid:dataset_id>/hit-testing') |