| 
									
										
										
										
											2023-11-21 15:38:27 +08:00
										 |  |  | import logging | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-06 13:21:13 +08:00
										 |  |  | from flask_login import current_user | 
					
						
							|  |  |  | from flask_restful import Resource, reqparse | 
					
						
							|  |  |  | from werkzeug.exceptions import Forbidden | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-12 00:57:00 +08:00
										 |  |  | from controllers.console import api | 
					
						
							|  |  |  | from controllers.console.setup import setup_required | 
					
						
							|  |  |  | from controllers.console.wraps import account_initialization_required | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | from core.model_runtime.entities.model_entities import ModelType | 
					
						
							|  |  |  | from core.model_runtime.errors.validate import CredentialsValidateFailedError | 
					
						
							|  |  |  | from core.model_runtime.utils.encoders import jsonable_encoder | 
					
						
							|  |  |  | from libs.login import login_required | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  | from services.model_load_balancing_service import ModelLoadBalancingService | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | from services.model_provider_service import ModelProviderService | 
					
						
							| 
									
										
										
										
											2023-08-12 00:57:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class DefaultModelApi(Resource): | 
					
						
							|  |  |  |     @setup_required | 
					
						
							|  |  |  |     @login_required | 
					
						
							|  |  |  |     @account_initialization_required | 
					
						
							|  |  |  |     def get(self): | 
					
						
							|  |  |  |         parser = reqparse.RequestParser() | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |         parser.add_argument( | 
					
						
							|  |  |  |             "model_type", | 
					
						
							|  |  |  |             type=str, | 
					
						
							|  |  |  |             required=True, | 
					
						
							|  |  |  |             nullable=False, | 
					
						
							|  |  |  |             choices=[mt.value for mt in ModelType], | 
					
						
							|  |  |  |             location="args", | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2023-08-12 00:57:00 +08:00
										 |  |  |         args = parser.parse_args() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         tenant_id = current_user.current_tenant_id | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         model_provider_service = ModelProviderService() | 
					
						
							|  |  |  |         default_model_entity = model_provider_service.get_default_model_of_model_type( | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |             tenant_id=tenant_id, model_type=args["model_type"] | 
					
						
							| 
									
										
										
										
											2023-08-12 00:57:00 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |         return jsonable_encoder({"data": default_model_entity}) | 
					
						
							| 
									
										
										
										
											2023-08-12 00:57:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     @setup_required | 
					
						
							|  |  |  |     @login_required | 
					
						
							|  |  |  |     @account_initialization_required | 
					
						
							|  |  |  |     def post(self): | 
					
						
							| 
									
										
										
										
											2024-06-14 07:34:25 -05:00
										 |  |  |         if not current_user.is_admin_or_owner: | 
					
						
							|  |  |  |             raise Forbidden() | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-12 00:57:00 +08:00
										 |  |  |         parser = reqparse.RequestParser() | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |         parser.add_argument("model_settings", type=list, required=True, nullable=False, location="json") | 
					
						
							| 
									
										
										
										
											2023-08-12 00:57:00 +08:00
										 |  |  |         args = parser.parse_args() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         tenant_id = current_user.current_tenant_id | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         model_provider_service = ModelProviderService() | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |         model_settings = args["model_settings"] | 
					
						
							| 
									
										
										
										
											2023-11-17 22:13:37 +08:00
										 |  |  |         for model_setting in model_settings: | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |             if "model_type" not in model_setting or model_setting["model_type"] not in [mt.value for mt in ModelType]: | 
					
						
							|  |  |  |                 raise ValueError("invalid model type") | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |             if "provider" not in model_setting: | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |                 continue | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |             if "model" not in model_setting: | 
					
						
							|  |  |  |                 raise ValueError("invalid model") | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-11-21 15:38:27 +08:00
										 |  |  |             try: | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |                 model_provider_service.update_default_model_of_model_type( | 
					
						
							|  |  |  |                     tenant_id=tenant_id, | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |                     model_type=model_setting["model_type"], | 
					
						
							|  |  |  |                     provider=model_setting["provider"], | 
					
						
							|  |  |  |                     model=model_setting["model"], | 
					
						
							| 
									
										
										
										
											2023-11-21 15:38:27 +08:00
										 |  |  |                 ) | 
					
						
							|  |  |  |             except Exception: | 
					
						
							|  |  |  |                 logging.warning(f"{model_setting['model_type']} save error") | 
					
						
							| 
									
										
										
										
											2023-08-12 00:57:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |         return {"result": "success"} | 
					
						
							| 
									
										
										
										
											2023-08-12 00:57:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | class ModelProviderModelApi(Resource): | 
					
						
							|  |  |  |     @setup_required | 
					
						
							|  |  |  |     @login_required | 
					
						
							|  |  |  |     @account_initialization_required | 
					
						
							|  |  |  |     def get(self, provider): | 
					
						
							|  |  |  |         tenant_id = current_user.current_tenant_id | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         model_provider_service = ModelProviderService() | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |         models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |         return jsonable_encoder({"data": models}) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     @setup_required | 
					
						
							|  |  |  |     @login_required | 
					
						
							|  |  |  |     @account_initialization_required | 
					
						
							|  |  |  |     def post(self, provider: str): | 
					
						
							| 
									
										
										
										
											2024-06-14 07:34:25 -05:00
										 |  |  |         if not current_user.is_admin_or_owner: | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             raise Forbidden() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         tenant_id = current_user.current_tenant_id | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         parser = reqparse.RequestParser() | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |         parser.add_argument("model", type=str, required=True, nullable=False, location="json") | 
					
						
							|  |  |  |         parser.add_argument( | 
					
						
							|  |  |  |             "model_type", | 
					
						
							|  |  |  |             type=str, | 
					
						
							|  |  |  |             required=True, | 
					
						
							|  |  |  |             nullable=False, | 
					
						
							|  |  |  |             choices=[mt.value for mt in ModelType], | 
					
						
							|  |  |  |             location="json", | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json") | 
					
						
							|  |  |  |         parser.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json") | 
					
						
							|  |  |  |         parser.add_argument("config_from", type=str, required=False, nullable=True, location="json") | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         args = parser.parse_args() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |         model_load_balancing_service = ModelLoadBalancingService() | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |         if ( | 
					
						
							|  |  |  |             "load_balancing" in args | 
					
						
							|  |  |  |             and args["load_balancing"] | 
					
						
							|  |  |  |             and "enabled" in args["load_balancing"] | 
					
						
							|  |  |  |             and args["load_balancing"]["enabled"] | 
					
						
							|  |  |  |         ): | 
					
						
							|  |  |  |             if "configs" not in args["load_balancing"]: | 
					
						
							|  |  |  |                 raise ValueError("invalid load balancing configs") | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             # save load balancing configs | 
					
						
							|  |  |  |             model_load_balancing_service.update_load_balancing_configs( | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |                 tenant_id=tenant_id, | 
					
						
							|  |  |  |                 provider=provider, | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |                 model=args["model"], | 
					
						
							|  |  |  |                 model_type=args["model_type"], | 
					
						
							|  |  |  |                 configs=args["load_balancing"]["configs"], | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             ) | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             # enable load balancing | 
					
						
							|  |  |  |             model_load_balancing_service.enable_model_load_balancing( | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |                 tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |             ) | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             # disable load balancing | 
					
						
							|  |  |  |             model_load_balancing_service.disable_model_load_balancing( | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |                 tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |             if args.get("config_from", "") != "predefined-model": | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |                 model_provider_service = ModelProviderService() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 try: | 
					
						
							|  |  |  |                     model_provider_service.save_model_credentials( | 
					
						
							|  |  |  |                         tenant_id=tenant_id, | 
					
						
							|  |  |  |                         provider=provider, | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |                         model=args["model"], | 
					
						
							|  |  |  |                         model_type=args["model_type"], | 
					
						
							|  |  |  |                         credentials=args["credentials"], | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |                     ) | 
					
						
							|  |  |  |                 except CredentialsValidateFailedError as ex: | 
					
						
							| 
									
										
										
										
											2024-08-26 14:46:29 +08:00
										 |  |  |                     logging.exception(f"save model credentials error: {ex}") | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |                     raise ValueError(str(ex)) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |         return {"result": "success"}, 200 | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     @setup_required | 
					
						
							|  |  |  |     @login_required | 
					
						
							|  |  |  |     @account_initialization_required | 
					
						
							|  |  |  |     def delete(self, provider: str): | 
					
						
							| 
									
										
										
										
											2024-06-14 07:34:25 -05:00
										 |  |  |         if not current_user.is_admin_or_owner: | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             raise Forbidden() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         tenant_id = current_user.current_tenant_id | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         parser = reqparse.RequestParser() | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |         parser.add_argument("model", type=str, required=True, nullable=False, location="json") | 
					
						
							|  |  |  |         parser.add_argument( | 
					
						
							|  |  |  |             "model_type", | 
					
						
							|  |  |  |             type=str, | 
					
						
							|  |  |  |             required=True, | 
					
						
							|  |  |  |             nullable=False, | 
					
						
							|  |  |  |             choices=[mt.value for mt in ModelType], | 
					
						
							|  |  |  |             location="json", | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         args = parser.parse_args() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         model_provider_service = ModelProviderService() | 
					
						
							|  |  |  |         model_provider_service.remove_model_credentials( | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |             tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |         return {"result": "success"}, 204 | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class ModelProviderModelCredentialApi(Resource): | 
					
						
							|  |  |  |     @setup_required | 
					
						
							|  |  |  |     @login_required | 
					
						
							|  |  |  |     @account_initialization_required | 
					
						
							|  |  |  |     def get(self, provider: str): | 
					
						
							|  |  |  |         tenant_id = current_user.current_tenant_id | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         parser = reqparse.RequestParser() | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |         parser.add_argument("model", type=str, required=True, nullable=False, location="args") | 
					
						
							|  |  |  |         parser.add_argument( | 
					
						
							|  |  |  |             "model_type", | 
					
						
							|  |  |  |             type=str, | 
					
						
							|  |  |  |             required=True, | 
					
						
							|  |  |  |             nullable=False, | 
					
						
							|  |  |  |             choices=[mt.value for mt in ModelType], | 
					
						
							|  |  |  |             location="args", | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         args = parser.parse_args() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         model_provider_service = ModelProviderService() | 
					
						
							|  |  |  |         credentials = model_provider_service.get_model_credentials( | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |             tenant_id=tenant_id, provider=provider, model_type=args["model_type"], model=args["model"] | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |         model_load_balancing_service = ModelLoadBalancingService() | 
					
						
							|  |  |  |         is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs( | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |             tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         return { | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |             "credentials": credentials, | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |             "load_balancing": {"enabled": is_load_balancing_enabled, "configs": load_balancing_configs}, | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  | class ModelProviderModelEnableApi(Resource): | 
					
						
							|  |  |  |     @setup_required | 
					
						
							|  |  |  |     @login_required | 
					
						
							|  |  |  |     @account_initialization_required | 
					
						
							|  |  |  |     def patch(self, provider: str): | 
					
						
							|  |  |  |         tenant_id = current_user.current_tenant_id | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         parser = reqparse.RequestParser() | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |         parser.add_argument("model", type=str, required=True, nullable=False, location="json") | 
					
						
							|  |  |  |         parser.add_argument( | 
					
						
							|  |  |  |             "model_type", | 
					
						
							|  |  |  |             type=str, | 
					
						
							|  |  |  |             required=True, | 
					
						
							|  |  |  |             nullable=False, | 
					
						
							|  |  |  |             choices=[mt.value for mt in ModelType], | 
					
						
							|  |  |  |             location="json", | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |         args = parser.parse_args() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         model_provider_service = ModelProviderService() | 
					
						
							|  |  |  |         model_provider_service.enable_model( | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |             tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |         return {"result": "success"} | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class ModelProviderModelDisableApi(Resource): | 
					
						
							|  |  |  |     @setup_required | 
					
						
							|  |  |  |     @login_required | 
					
						
							|  |  |  |     @account_initialization_required | 
					
						
							|  |  |  |     def patch(self, provider: str): | 
					
						
							|  |  |  |         tenant_id = current_user.current_tenant_id | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         parser = reqparse.RequestParser() | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |         parser.add_argument("model", type=str, required=True, nullable=False, location="json") | 
					
						
							|  |  |  |         parser.add_argument( | 
					
						
							|  |  |  |             "model_type", | 
					
						
							|  |  |  |             type=str, | 
					
						
							|  |  |  |             required=True, | 
					
						
							|  |  |  |             nullable=False, | 
					
						
							|  |  |  |             choices=[mt.value for mt in ModelType], | 
					
						
							|  |  |  |             location="json", | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |         args = parser.parse_args() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         model_provider_service = ModelProviderService() | 
					
						
							|  |  |  |         model_provider_service.disable_model( | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |             tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |         return {"result": "success"} | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | class ModelProviderModelValidateApi(Resource): | 
					
						
							|  |  |  |     @setup_required | 
					
						
							|  |  |  |     @login_required | 
					
						
							|  |  |  |     @account_initialization_required | 
					
						
							|  |  |  |     def post(self, provider: str): | 
					
						
							|  |  |  |         tenant_id = current_user.current_tenant_id | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         parser = reqparse.RequestParser() | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |         parser.add_argument("model", type=str, required=True, nullable=False, location="json") | 
					
						
							|  |  |  |         parser.add_argument( | 
					
						
							|  |  |  |             "model_type", | 
					
						
							|  |  |  |             type=str, | 
					
						
							|  |  |  |             required=True, | 
					
						
							|  |  |  |             nullable=False, | 
					
						
							|  |  |  |             choices=[mt.value for mt in ModelType], | 
					
						
							|  |  |  |             location="json", | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |         parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         args = parser.parse_args() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         model_provider_service = ModelProviderService() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         result = True | 
					
						
							|  |  |  |         error = None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             model_provider_service.model_credentials_validate( | 
					
						
							|  |  |  |                 tenant_id=tenant_id, | 
					
						
							|  |  |  |                 provider=provider, | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |                 model=args["model"], | 
					
						
							|  |  |  |                 model_type=args["model_type"], | 
					
						
							|  |  |  |                 credentials=args["credentials"], | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             ) | 
					
						
							|  |  |  |         except CredentialsValidateFailedError as ex: | 
					
						
							|  |  |  |             result = False | 
					
						
							|  |  |  |             error = str(ex) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |         response = {"result": "success" if result else "error"} | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         if not result: | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |             response["error"] = error | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         return response | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class ModelProviderModelParameterRuleApi(Resource): | 
					
						
							|  |  |  |     @setup_required | 
					
						
							|  |  |  |     @login_required | 
					
						
							|  |  |  |     @account_initialization_required | 
					
						
							|  |  |  |     def get(self, provider: str): | 
					
						
							|  |  |  |         parser = reqparse.RequestParser() | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |         parser.add_argument("model", type=str, required=True, nullable=False, location="args") | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         args = parser.parse_args() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         tenant_id = current_user.current_tenant_id | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         model_provider_service = ModelProviderService() | 
					
						
							|  |  |  |         parameter_rules = model_provider_service.get_model_parameter_rules( | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |             tenant_id=tenant_id, provider=provider, model=args["model"] | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |         return jsonable_encoder({"data": parameter_rules}) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class ModelProviderAvailableModelApi(Resource): | 
					
						
							| 
									
										
										
										
											2023-08-12 00:57:00 +08:00
										 |  |  |     @setup_required | 
					
						
							|  |  |  |     @login_required | 
					
						
							|  |  |  |     @account_initialization_required | 
					
						
							|  |  |  |     def get(self, model_type): | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         tenant_id = current_user.current_tenant_id | 
					
						
							| 
									
										
										
										
											2023-08-12 00:57:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         model_provider_service = ModelProviderService() | 
					
						
							| 
									
										
										
										
											2024-08-26 15:29:10 +08:00
										 |  |  |         models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return jsonable_encoder({"data": models}) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers/<string:provider>/models") | 
					
						
							|  |  |  | api.add_resource( | 
					
						
							|  |  |  |     ModelProviderModelEnableApi, | 
					
						
							|  |  |  |     "/workspaces/current/model-providers/<string:provider>/models/enable", | 
					
						
							|  |  |  |     endpoint="model-provider-model-enable", | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | api.add_resource( | 
					
						
							|  |  |  |     ModelProviderModelDisableApi, | 
					
						
							|  |  |  |     "/workspaces/current/model-providers/<string:provider>/models/disable", | 
					
						
							|  |  |  |     endpoint="model-provider-model-disable", | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | api.add_resource( | 
					
						
							|  |  |  |     ModelProviderModelCredentialApi, "/workspaces/current/model-providers/<string:provider>/models/credentials" | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | api.add_resource( | 
					
						
							|  |  |  |     ModelProviderModelValidateApi, "/workspaces/current/model-providers/<string:provider>/models/credentials/validate" | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | api.add_resource( | 
					
						
							|  |  |  |     ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers/<string:provider>/models/parameter-rules" | 
					
						
							|  |  |  | ) | 
					
						
							|  |  |  | api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>") | 
					
						
							|  |  |  | api.add_resource(DefaultModelApi, "/workspaces/current/default-model") |