| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | import datetime | 
					
						
							|  |  |  | import json | 
					
						
							| 
									
										
										
										
											2024-01-04 10:53:50 +08:00
										 |  |  | import logging | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  | from collections import defaultdict | 
					
						
							| 
									
										
										
										
											2024-02-09 15:21:33 +08:00
										 |  |  | from collections.abc import Iterator | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | from json import JSONDecodeError | 
					
						
							| 
									
										
										
										
											2024-02-09 15:21:33 +08:00
										 |  |  | from typing import Optional | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-14 01:05:37 +08:00
										 |  |  | from pydantic import BaseModel, ConfigDict | 
					
						
							| 
									
										
										
										
											2024-02-06 13:21:13 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-07 17:30:56 +08:00
										 |  |  | from constants import HIDDEN_VALUE | 
					
						
							| 
									
										
										
										
											2024-01-12 12:34:01 +08:00
										 |  |  | from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  | from core.entities.provider_entities import ( | 
					
						
							|  |  |  |     CustomConfiguration, | 
					
						
							|  |  |  |     ModelSettings, | 
					
						
							|  |  |  |     SystemConfiguration, | 
					
						
							|  |  |  |     SystemConfigurationStatus, | 
					
						
							|  |  |  | ) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | from core.helper import encrypter | 
					
						
							| 
									
										
										
										
											2024-01-04 20:48:54 +08:00
										 |  |  | from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType | 
					
						
							| 
									
										
										
										
											2024-01-12 12:34:01 +08:00
										 |  |  | from core.model_runtime.entities.model_entities import FetchFrom, ModelType | 
					
						
							| 
									
										
										
										
											2024-02-06 13:21:13 +08:00
										 |  |  | from core.model_runtime.entities.provider_entities import ( | 
					
						
							|  |  |  |     ConfigurateMethod, | 
					
						
							|  |  |  |     CredentialFormSchema, | 
					
						
							|  |  |  |     FormType, | 
					
						
							|  |  |  |     ProviderEntity, | 
					
						
							|  |  |  | ) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | from core.model_runtime.model_providers import model_provider_factory | 
					
						
							|  |  |  | from core.model_runtime.model_providers.__base.ai_model import AIModel | 
					
						
							|  |  |  | from core.model_runtime.model_providers.__base.model_provider import ModelProvider | 
					
						
							|  |  |  | from extensions.ext_database import db | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  | from models.provider import ( | 
					
						
							|  |  |  |     LoadBalancingModelConfig, | 
					
						
							|  |  |  |     Provider, | 
					
						
							|  |  |  |     ProviderModel, | 
					
						
							|  |  |  |     ProviderModelSetting, | 
					
						
							|  |  |  |     ProviderType, | 
					
						
							|  |  |  |     TenantPreferredModelProvider, | 
					
						
							|  |  |  | ) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-04 10:53:50 +08:00
										 |  |  | logger = logging.getLogger(__name__) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-09 19:17:47 +08:00
										 |  |  | original_provider_configurate_methods = {} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | class ProviderConfiguration(BaseModel): | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     Model class for provider configuration. | 
					
						
							|  |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |     tenant_id: str | 
					
						
							|  |  |  |     provider: ProviderEntity | 
					
						
							|  |  |  |     preferred_provider_type: ProviderType | 
					
						
							|  |  |  |     using_provider_type: ProviderType | 
					
						
							|  |  |  |     system_configuration: SystemConfiguration | 
					
						
							|  |  |  |     custom_configuration: CustomConfiguration | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |     model_settings: list[ModelSettings] | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-14 01:05:37 +08:00
										 |  |  |     # pydantic configs | 
					
						
							|  |  |  |     model_config = ConfigDict(protected_namespaces=()) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-09 19:17:47 +08:00
										 |  |  |     def __init__(self, **data): | 
					
						
							|  |  |  |         super().__init__(**data) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if self.provider.provider not in original_provider_configurate_methods: | 
					
						
							|  |  |  |             original_provider_configurate_methods[self.provider.provider] = [] | 
					
						
							|  |  |  |             for configurate_method in self.provider.configurate_methods: | 
					
						
							|  |  |  |                 original_provider_configurate_methods[self.provider.provider].append(configurate_method) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]: | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |             if ( | 
					
						
							|  |  |  |                 any( | 
					
						
							|  |  |  |                     len(quota_configuration.restrict_models) > 0 | 
					
						
							|  |  |  |                     for quota_configuration in self.system_configuration.quota_configurations | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |                 and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods | 
					
						
							|  |  |  |             ): | 
					
						
							| 
									
										
										
										
											2024-01-09 19:17:47 +08:00
										 |  |  |                 self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |     def get_current_credentials(self, model_type: ModelType, model: str) -> Optional[dict]: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Get current credentials. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :param model_type: model type | 
					
						
							|  |  |  |         :param model: model name | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |         if self.model_settings: | 
					
						
							|  |  |  |             # check if model is disabled by admin | 
					
						
							|  |  |  |             for model_setting in self.model_settings: | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 if model_setting.model_type == model_type and model_setting.model == model: | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |                     if not model_setting.enabled: | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                         raise ValueError(f"Model {model} is disabled.") | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         if self.using_provider_type == ProviderType.SYSTEM: | 
					
						
							| 
									
										
										
										
											2024-01-11 10:49:35 +08:00
										 |  |  |             restrict_models = [] | 
					
						
							|  |  |  |             for quota_configuration in self.system_configuration.quota_configurations: | 
					
						
							|  |  |  |                 if self.system_configuration.current_quota_type != quota_configuration.quota_type: | 
					
						
							|  |  |  |                     continue | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 restrict_models = quota_configuration.restrict_models | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             copy_credentials = self.system_configuration.credentials.copy() | 
					
						
							|  |  |  |             if restrict_models: | 
					
						
							|  |  |  |                 for restrict_model in restrict_models: | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                     if ( | 
					
						
							|  |  |  |                         restrict_model.model_type == model_type | 
					
						
							|  |  |  |                         and restrict_model.model == model | 
					
						
							|  |  |  |                         and restrict_model.base_model_name | 
					
						
							|  |  |  |                     ): | 
					
						
							|  |  |  |                         copy_credentials["base_model_name"] = restrict_model.base_model_name | 
					
						
							| 
									
										
										
										
											2024-01-11 10:49:35 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             return copy_credentials | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         else: | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |             credentials = None | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             if self.custom_configuration.models: | 
					
						
							|  |  |  |                 for model_configuration in self.custom_configuration.models: | 
					
						
							|  |  |  |                     if model_configuration.model_type == model_type and model_configuration.model == model: | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |                         credentials = model_configuration.credentials | 
					
						
							|  |  |  |                         break | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-28 15:54:26 +08:00
										 |  |  |             if not credentials and self.custom_configuration.provider: | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |                 credentials = self.custom_configuration.provider.credentials | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             return credentials | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def get_system_configuration_status(self) -> SystemConfigurationStatus: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Get system configuration status. | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         if self.system_configuration.enabled is False: | 
					
						
							|  |  |  |             return SystemConfigurationStatus.UNSUPPORTED | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         current_quota_type = self.system_configuration.current_quota_type | 
					
						
							|  |  |  |         current_quota_configuration = next( | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |             (q for q in self.system_configuration.quota_configurations if q.quota_type == current_quota_type), None | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         return ( | 
					
						
							|  |  |  |             SystemConfigurationStatus.ACTIVE | 
					
						
							|  |  |  |             if current_quota_configuration.is_valid | 
					
						
							|  |  |  |             else SystemConfigurationStatus.QUOTA_EXCEEDED | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def is_custom_configuration_available(self) -> bool: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Check custom configuration available. | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         return self.custom_configuration.provider is not None or len(self.custom_configuration.models) > 0 | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def get_custom_credentials(self, obfuscated: bool = False) -> Optional[dict]: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Get custom credentials. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :param obfuscated: obfuscated secret data in credentials | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         if self.custom_configuration.provider is None: | 
					
						
							|  |  |  |             return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         credentials = self.custom_configuration.provider.credentials | 
					
						
							|  |  |  |         if not obfuscated: | 
					
						
							|  |  |  |             return credentials | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Obfuscate credentials | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |         return self.obfuscated_credentials( | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             credentials=credentials, | 
					
						
							|  |  |  |             credential_form_schemas=self.provider.provider_credential_schema.credential_form_schemas | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |             if self.provider.provider_credential_schema | 
					
						
							|  |  |  |             else [], | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-09 15:21:33 +08:00
										 |  |  |     def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]: | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Validate custom credentials. | 
					
						
							|  |  |  |         :param credentials: provider credentials | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         # get provider | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         provider_record = ( | 
					
						
							|  |  |  |             db.session.query(Provider) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             .filter( | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 Provider.tenant_id == self.tenant_id, | 
					
						
							|  |  |  |                 Provider.provider_name == self.provider.provider, | 
					
						
							|  |  |  |                 Provider.provider_type == ProviderType.CUSTOM.value, | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             .first() | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         # Get provider credential secret variables | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |         provider_credential_secret_variables = self.extract_secret_variables( | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             self.provider.provider_credential_schema.credential_form_schemas | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |             if self.provider.provider_credential_schema | 
					
						
							|  |  |  |             else [] | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if provider_record: | 
					
						
							|  |  |  |             try: | 
					
						
							| 
									
										
										
										
											2024-01-24 17:19:25 +08:00
										 |  |  |                 # fix origin data | 
					
						
							|  |  |  |                 if provider_record.encrypted_config: | 
					
						
							|  |  |  |                     if not provider_record.encrypted_config.startswith("{"): | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                         original_credentials = {"openai_api_key": provider_record.encrypted_config} | 
					
						
							| 
									
										
										
										
											2024-01-24 17:19:25 +08:00
										 |  |  |                     else: | 
					
						
							|  |  |  |                         original_credentials = json.loads(provider_record.encrypted_config) | 
					
						
							|  |  |  |                 else: | 
					
						
							|  |  |  |                     original_credentials = {} | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             except JSONDecodeError: | 
					
						
							|  |  |  |                 original_credentials = {} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             # encrypt credentials | 
					
						
							|  |  |  |             for key, value in credentials.items(): | 
					
						
							|  |  |  |                 if key in provider_credential_secret_variables: | 
					
						
							|  |  |  |                     # if send [__HIDDEN__] in secret input, it will be same as original value | 
					
						
							| 
									
										
										
										
											2024-08-07 17:30:56 +08:00
										 |  |  |                     if value == HIDDEN_VALUE and key in original_credentials: | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |                         credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-12 16:48:38 +08:00
										 |  |  |         credentials = model_provider_factory.provider_credentials_validate( | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |             provider=self.provider.provider, credentials=credentials | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         for key, value in credentials.items(): | 
					
						
							|  |  |  |             if key in provider_credential_secret_variables: | 
					
						
							|  |  |  |                 credentials[key] = encrypter.encrypt_token(self.tenant_id, value) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return provider_record, credentials | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def add_or_update_custom_credentials(self, credentials: dict) -> None: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Add or update custom provider credentials. | 
					
						
							|  |  |  |         :param credentials: | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         # validate custom provider config | 
					
						
							|  |  |  |         provider_record, credentials = self.custom_credentials_validate(credentials) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # save provider | 
					
						
							|  |  |  |         # Note: Do not switch the preferred provider, which allows users to use quotas first | 
					
						
							|  |  |  |         if provider_record: | 
					
						
							|  |  |  |             provider_record.encrypted_config = json.dumps(credentials) | 
					
						
							|  |  |  |             provider_record.is_valid = True | 
					
						
							| 
									
										
										
										
											2024-04-12 16:22:24 +08:00
										 |  |  |             provider_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             db.session.commit() | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             provider_record = Provider( | 
					
						
							|  |  |  |                 tenant_id=self.tenant_id, | 
					
						
							|  |  |  |                 provider_name=self.provider.provider, | 
					
						
							|  |  |  |                 provider_type=ProviderType.CUSTOM.value, | 
					
						
							|  |  |  |                 encrypted_config=json.dumps(credentials), | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 is_valid=True, | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             ) | 
					
						
							|  |  |  |             db.session.add(provider_record) | 
					
						
							|  |  |  |             db.session.commit() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-04 20:48:54 +08:00
										 |  |  |         provider_model_credentials_cache = ProviderCredentialsCache( | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |             tenant_id=self.tenant_id, identity_id=provider_record.id, cache_type=ProviderCredentialsCacheType.PROVIDER | 
					
						
							| 
									
										
										
										
											2024-01-04 20:48:54 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         provider_model_credentials_cache.delete() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         self.switch_preferred_provider_type(ProviderType.CUSTOM) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def delete_custom_credentials(self) -> None: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Delete custom provider credentials. | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         # get provider | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         provider_record = ( | 
					
						
							|  |  |  |             db.session.query(Provider) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             .filter( | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 Provider.tenant_id == self.tenant_id, | 
					
						
							|  |  |  |                 Provider.provider_name == self.provider.provider, | 
					
						
							|  |  |  |                 Provider.provider_type == ProviderType.CUSTOM.value, | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             .first() | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         # delete provider | 
					
						
							|  |  |  |         if provider_record: | 
					
						
							|  |  |  |             self.switch_preferred_provider_type(ProviderType.SYSTEM) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             db.session.delete(provider_record) | 
					
						
							|  |  |  |             db.session.commit() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-04 20:48:54 +08:00
										 |  |  |             provider_model_credentials_cache = ProviderCredentialsCache( | 
					
						
							|  |  |  |                 tenant_id=self.tenant_id, | 
					
						
							|  |  |  |                 identity_id=provider_record.id, | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 cache_type=ProviderCredentialsCacheType.PROVIDER, | 
					
						
							| 
									
										
										
										
											2024-01-04 20:48:54 +08:00
										 |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             provider_model_credentials_cache.delete() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |     def get_custom_model_credentials( | 
					
						
							|  |  |  |         self, model_type: ModelType, model: str, obfuscated: bool = False | 
					
						
							|  |  |  |     ) -> Optional[dict]: | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Get custom model credentials. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :param model_type: model type | 
					
						
							|  |  |  |         :param model: model name | 
					
						
							|  |  |  |         :param obfuscated: obfuscated secret data in credentials | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         if not self.custom_configuration.models: | 
					
						
							|  |  |  |             return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         for model_configuration in self.custom_configuration.models: | 
					
						
							|  |  |  |             if model_configuration.model_type == model_type and model_configuration.model == model: | 
					
						
							|  |  |  |                 credentials = model_configuration.credentials | 
					
						
							|  |  |  |                 if not obfuscated: | 
					
						
							|  |  |  |                     return credentials | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 # Obfuscate credentials | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |                 return self.obfuscated_credentials( | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |                     credentials=credentials, | 
					
						
							|  |  |  |                     credential_form_schemas=self.provider.model_credential_schema.credential_form_schemas | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                     if self.provider.model_credential_schema | 
					
						
							|  |  |  |                     else [], | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |                 ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return None | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |     def custom_model_credentials_validate( | 
					
						
							|  |  |  |         self, model_type: ModelType, model: str, credentials: dict | 
					
						
							|  |  |  |     ) -> tuple[ProviderModel, dict]: | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Validate custom model credentials. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :param model_type: model type | 
					
						
							|  |  |  |         :param model: model name | 
					
						
							|  |  |  |         :param credentials: model credentials | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         # get provider model | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         provider_model_record = ( | 
					
						
							|  |  |  |             db.session.query(ProviderModel) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             .filter( | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 ProviderModel.tenant_id == self.tenant_id, | 
					
						
							|  |  |  |                 ProviderModel.provider_name == self.provider.provider, | 
					
						
							|  |  |  |                 ProviderModel.model_name == model, | 
					
						
							|  |  |  |                 ProviderModel.model_type == model_type.to_origin_model_type(), | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             .first() | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         # Get provider credential secret variables | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |         provider_credential_secret_variables = self.extract_secret_variables( | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             self.provider.model_credential_schema.credential_form_schemas | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |             if self.provider.model_credential_schema | 
					
						
							|  |  |  |             else [] | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if provider_model_record: | 
					
						
							|  |  |  |             try: | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 original_credentials = ( | 
					
						
							|  |  |  |                     json.loads(provider_model_record.encrypted_config) if provider_model_record.encrypted_config else {} | 
					
						
							|  |  |  |                 ) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             except JSONDecodeError: | 
					
						
							|  |  |  |                 original_credentials = {} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             # decrypt credentials | 
					
						
							|  |  |  |             for key, value in credentials.items(): | 
					
						
							|  |  |  |                 if key in provider_credential_secret_variables: | 
					
						
							|  |  |  |                     # if send [__HIDDEN__] in secret input, it will be same as original value | 
					
						
							| 
									
										
										
										
											2024-08-07 17:30:56 +08:00
										 |  |  |                     if value == HIDDEN_VALUE and key in original_credentials: | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |                         credentials[key] = encrypter.decrypt_token(self.tenant_id, original_credentials[key]) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-12 16:48:38 +08:00
										 |  |  |         credentials = model_provider_factory.model_credentials_validate( | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |             provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         for key, value in credentials.items(): | 
					
						
							|  |  |  |             if key in provider_credential_secret_variables: | 
					
						
							|  |  |  |                 credentials[key] = encrypter.encrypt_token(self.tenant_id, value) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return provider_model_record, credentials | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def add_or_update_custom_model_credentials(self, model_type: ModelType, model: str, credentials: dict) -> None: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Add or update custom model credentials. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :param model_type: model type | 
					
						
							|  |  |  |         :param model: model name | 
					
						
							|  |  |  |         :param credentials: model credentials | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         # validate custom model config | 
					
						
							|  |  |  |         provider_model_record, credentials = self.custom_model_credentials_validate(model_type, model, credentials) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # save provider model | 
					
						
							|  |  |  |         # Note: Do not switch the preferred provider, which allows users to use quotas first | 
					
						
							|  |  |  |         if provider_model_record: | 
					
						
							|  |  |  |             provider_model_record.encrypted_config = json.dumps(credentials) | 
					
						
							|  |  |  |             provider_model_record.is_valid = True | 
					
						
							| 
									
										
										
										
											2024-04-12 16:22:24 +08:00
										 |  |  |             provider_model_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             db.session.commit() | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             provider_model_record = ProviderModel( | 
					
						
							|  |  |  |                 tenant_id=self.tenant_id, | 
					
						
							|  |  |  |                 provider_name=self.provider.provider, | 
					
						
							|  |  |  |                 model_name=model, | 
					
						
							|  |  |  |                 model_type=model_type.to_origin_model_type(), | 
					
						
							|  |  |  |                 encrypted_config=json.dumps(credentials), | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 is_valid=True, | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             ) | 
					
						
							|  |  |  |             db.session.add(provider_model_record) | 
					
						
							|  |  |  |             db.session.commit() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-04 20:48:54 +08:00
										 |  |  |         provider_model_credentials_cache = ProviderCredentialsCache( | 
					
						
							|  |  |  |             tenant_id=self.tenant_id, | 
					
						
							|  |  |  |             identity_id=provider_model_record.id, | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |             cache_type=ProviderCredentialsCacheType.MODEL, | 
					
						
							| 
									
										
										
										
											2024-01-04 20:48:54 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         provider_model_credentials_cache.delete() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |     def delete_custom_model_credentials(self, model_type: ModelType, model: str) -> None: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Delete custom model credentials. | 
					
						
							|  |  |  |         :param model_type: model type | 
					
						
							|  |  |  |         :param model: model name | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         # get provider model | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         provider_model_record = ( | 
					
						
							|  |  |  |             db.session.query(ProviderModel) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             .filter( | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 ProviderModel.tenant_id == self.tenant_id, | 
					
						
							|  |  |  |                 ProviderModel.provider_name == self.provider.provider, | 
					
						
							|  |  |  |                 ProviderModel.model_name == model, | 
					
						
							|  |  |  |                 ProviderModel.model_type == model_type.to_origin_model_type(), | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             .first() | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         # delete provider model | 
					
						
							|  |  |  |         if provider_model_record: | 
					
						
							|  |  |  |             db.session.delete(provider_model_record) | 
					
						
							|  |  |  |             db.session.commit() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-04 20:48:54 +08:00
										 |  |  |             provider_model_credentials_cache = ProviderCredentialsCache( | 
					
						
							|  |  |  |                 tenant_id=self.tenant_id, | 
					
						
							|  |  |  |                 identity_id=provider_model_record.id, | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 cache_type=ProviderCredentialsCacheType.MODEL, | 
					
						
							| 
									
										
										
										
											2024-01-04 20:48:54 +08:00
										 |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             provider_model_credentials_cache.delete() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |     def enable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Enable model. | 
					
						
							|  |  |  |         :param model_type: model type | 
					
						
							|  |  |  |         :param model: model name | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         model_setting = ( | 
					
						
							|  |  |  |             db.session.query(ProviderModelSetting) | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |             .filter( | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 ProviderModelSetting.tenant_id == self.tenant_id, | 
					
						
							|  |  |  |                 ProviderModelSetting.provider_name == self.provider.provider, | 
					
						
							|  |  |  |                 ProviderModelSetting.model_type == model_type.to_origin_model_type(), | 
					
						
							|  |  |  |                 ProviderModelSetting.model_name == model, | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             .first() | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         if model_setting: | 
					
						
							|  |  |  |             model_setting.enabled = True | 
					
						
							|  |  |  |             model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) | 
					
						
							|  |  |  |             db.session.commit() | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             model_setting = ProviderModelSetting( | 
					
						
							|  |  |  |                 tenant_id=self.tenant_id, | 
					
						
							|  |  |  |                 provider_name=self.provider.provider, | 
					
						
							|  |  |  |                 model_type=model_type.to_origin_model_type(), | 
					
						
							|  |  |  |                 model_name=model, | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 enabled=True, | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |             ) | 
					
						
							|  |  |  |             db.session.add(model_setting) | 
					
						
							|  |  |  |             db.session.commit() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return model_setting | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def disable_model(self, model_type: ModelType, model: str) -> ProviderModelSetting: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Disable model. | 
					
						
							|  |  |  |         :param model_type: model type | 
					
						
							|  |  |  |         :param model: model name | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         model_setting = ( | 
					
						
							|  |  |  |             db.session.query(ProviderModelSetting) | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |             .filter( | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 ProviderModelSetting.tenant_id == self.tenant_id, | 
					
						
							|  |  |  |                 ProviderModelSetting.provider_name == self.provider.provider, | 
					
						
							|  |  |  |                 ProviderModelSetting.model_type == model_type.to_origin_model_type(), | 
					
						
							|  |  |  |                 ProviderModelSetting.model_name == model, | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             .first() | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         if model_setting: | 
					
						
							|  |  |  |             model_setting.enabled = False | 
					
						
							|  |  |  |             model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) | 
					
						
							|  |  |  |             db.session.commit() | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             model_setting = ProviderModelSetting( | 
					
						
							|  |  |  |                 tenant_id=self.tenant_id, | 
					
						
							|  |  |  |                 provider_name=self.provider.provider, | 
					
						
							|  |  |  |                 model_type=model_type.to_origin_model_type(), | 
					
						
							|  |  |  |                 model_name=model, | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 enabled=False, | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |             ) | 
					
						
							|  |  |  |             db.session.add(model_setting) | 
					
						
							|  |  |  |             db.session.commit() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return model_setting | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_provider_model_setting(self, model_type: ModelType, model: str) -> Optional[ProviderModelSetting]: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Get provider model setting. | 
					
						
							|  |  |  |         :param model_type: model type | 
					
						
							|  |  |  |         :param model: model name | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         return ( | 
					
						
							|  |  |  |             db.session.query(ProviderModelSetting) | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |             .filter( | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 ProviderModelSetting.tenant_id == self.tenant_id, | 
					
						
							|  |  |  |                 ProviderModelSetting.provider_name == self.provider.provider, | 
					
						
							|  |  |  |                 ProviderModelSetting.model_type == model_type.to_origin_model_type(), | 
					
						
							|  |  |  |                 ProviderModelSetting.model_name == model, | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             .first() | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def enable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Enable model load balancing. | 
					
						
							|  |  |  |         :param model_type: model type | 
					
						
							|  |  |  |         :param model: model name | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         load_balancing_config_count = ( | 
					
						
							|  |  |  |             db.session.query(LoadBalancingModelConfig) | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |             .filter( | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 LoadBalancingModelConfig.tenant_id == self.tenant_id, | 
					
						
							|  |  |  |                 LoadBalancingModelConfig.provider_name == self.provider.provider, | 
					
						
							|  |  |  |                 LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(), | 
					
						
							|  |  |  |                 LoadBalancingModelConfig.model_name == model, | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             .count() | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         if load_balancing_config_count <= 1: | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |             raise ValueError("Model load balancing configuration must be more than 1.") | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         model_setting = ( | 
					
						
							|  |  |  |             db.session.query(ProviderModelSetting) | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |             .filter( | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 ProviderModelSetting.tenant_id == self.tenant_id, | 
					
						
							|  |  |  |                 ProviderModelSetting.provider_name == self.provider.provider, | 
					
						
							|  |  |  |                 ProviderModelSetting.model_type == model_type.to_origin_model_type(), | 
					
						
							|  |  |  |                 ProviderModelSetting.model_name == model, | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             .first() | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         if model_setting: | 
					
						
							|  |  |  |             model_setting.load_balancing_enabled = True | 
					
						
							|  |  |  |             model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) | 
					
						
							|  |  |  |             db.session.commit() | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             model_setting = ProviderModelSetting( | 
					
						
							|  |  |  |                 tenant_id=self.tenant_id, | 
					
						
							|  |  |  |                 provider_name=self.provider.provider, | 
					
						
							|  |  |  |                 model_type=model_type.to_origin_model_type(), | 
					
						
							|  |  |  |                 model_name=model, | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 load_balancing_enabled=True, | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |             ) | 
					
						
							|  |  |  |             db.session.add(model_setting) | 
					
						
							|  |  |  |             db.session.commit() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return model_setting | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def disable_model_load_balancing(self, model_type: ModelType, model: str) -> ProviderModelSetting: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Disable model load balancing. | 
					
						
							|  |  |  |         :param model_type: model type | 
					
						
							|  |  |  |         :param model: model name | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         model_setting = ( | 
					
						
							|  |  |  |             db.session.query(ProviderModelSetting) | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |             .filter( | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 ProviderModelSetting.tenant_id == self.tenant_id, | 
					
						
							|  |  |  |                 ProviderModelSetting.provider_name == self.provider.provider, | 
					
						
							|  |  |  |                 ProviderModelSetting.model_type == model_type.to_origin_model_type(), | 
					
						
							|  |  |  |                 ProviderModelSetting.model_name == model, | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             .first() | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         if model_setting: | 
					
						
							|  |  |  |             model_setting.load_balancing_enabled = False | 
					
						
							|  |  |  |             model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) | 
					
						
							|  |  |  |             db.session.commit() | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             model_setting = ProviderModelSetting( | 
					
						
							|  |  |  |                 tenant_id=self.tenant_id, | 
					
						
							|  |  |  |                 provider_name=self.provider.provider, | 
					
						
							|  |  |  |                 model_type=model_type.to_origin_model_type(), | 
					
						
							|  |  |  |                 model_name=model, | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 load_balancing_enabled=False, | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |             ) | 
					
						
							|  |  |  |             db.session.add(model_setting) | 
					
						
							|  |  |  |             db.session.commit() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return model_setting | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |     def get_provider_instance(self) -> ModelProvider: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Get provider instance. | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         return model_provider_factory.get_provider_instance(self.provider.provider) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_model_type_instance(self, model_type: ModelType) -> AIModel: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Get current model type instance. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :param model_type: model type | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         # Get provider instance | 
					
						
							|  |  |  |         provider_instance = self.get_provider_instance() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Get model instance of LLM | 
					
						
							|  |  |  |         return provider_instance.get_model_instance(model_type) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def switch_preferred_provider_type(self, provider_type: ProviderType) -> None: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Switch preferred provider type. | 
					
						
							|  |  |  |         :param provider_type: | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         if provider_type == self.preferred_provider_type: | 
					
						
							|  |  |  |             return | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if provider_type == ProviderType.SYSTEM and not self.system_configuration.enabled: | 
					
						
							|  |  |  |             return | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # get preferred provider | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         preferred_model_provider = ( | 
					
						
							|  |  |  |             db.session.query(TenantPreferredModelProvider) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             .filter( | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 TenantPreferredModelProvider.tenant_id == self.tenant_id, | 
					
						
							|  |  |  |                 TenantPreferredModelProvider.provider_name == self.provider.provider, | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             .first() | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         if preferred_model_provider: | 
					
						
							|  |  |  |             preferred_model_provider.preferred_provider_type = provider_type.value | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             preferred_model_provider = TenantPreferredModelProvider( | 
					
						
							|  |  |  |                 tenant_id=self.tenant_id, | 
					
						
							|  |  |  |                 provider_name=self.provider.provider, | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 preferred_provider_type=provider_type.value, | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             ) | 
					
						
							|  |  |  |             db.session.add(preferred_model_provider) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         db.session.commit() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |     def extract_secret_variables(self, credential_form_schemas: list[CredentialFormSchema]) -> list[str]: | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Extract secret input form variables. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :param credential_form_schemas: | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         secret_input_form_variables = [] | 
					
						
							|  |  |  |         for credential_form_schema in credential_form_schemas: | 
					
						
							|  |  |  |             if credential_form_schema.type == FormType.SECRET_INPUT: | 
					
						
							|  |  |  |                 secret_input_form_variables.append(credential_form_schema.variable) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return secret_input_form_variables | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |     def obfuscated_credentials(self, credentials: dict, credential_form_schemas: list[CredentialFormSchema]) -> dict: | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Obfuscated credentials. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :param credentials: credentials | 
					
						
							|  |  |  |         :param credential_form_schemas: credential form schemas | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         # Get provider credential secret variables | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         credential_secret_variables = self.extract_secret_variables(credential_form_schemas) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         # Obfuscate provider credentials | 
					
						
							|  |  |  |         copy_credentials = credentials.copy() | 
					
						
							|  |  |  |         for key, value in copy_credentials.items(): | 
					
						
							|  |  |  |             if key in credential_secret_variables: | 
					
						
							|  |  |  |                 copy_credentials[key] = encrypter.obfuscated_token(value) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return copy_credentials | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |     def get_provider_model( | 
					
						
							|  |  |  |         self, model_type: ModelType, model: str, only_active: bool = False | 
					
						
							|  |  |  |     ) -> Optional[ModelWithProviderEntity]: | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Get provider model. | 
					
						
							|  |  |  |         :param model_type: model type | 
					
						
							|  |  |  |         :param model: model name | 
					
						
							|  |  |  |         :param only_active: return active model only | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         provider_models = self.get_provider_models(model_type, only_active) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         for provider_model in provider_models: | 
					
						
							|  |  |  |             if provider_model.model == model: | 
					
						
							|  |  |  |                 return provider_model | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return None | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |     def get_provider_models( | 
					
						
							|  |  |  |         self, model_type: Optional[ModelType] = None, only_active: bool = False | 
					
						
							|  |  |  |     ) -> list[ModelWithProviderEntity]: | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Get provider models. | 
					
						
							|  |  |  |         :param model_type: model type | 
					
						
							|  |  |  |         :param only_active: only active models | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         provider_instance = self.get_provider_instance() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         model_types = [] | 
					
						
							|  |  |  |         if model_type: | 
					
						
							|  |  |  |             model_types.append(model_type) | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             model_types = provider_instance.get_provider_schema().supported_model_types | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |         # Group model settings by model type and model | 
					
						
							|  |  |  |         model_setting_map = defaultdict(dict) | 
					
						
							|  |  |  |         for model_setting in self.model_settings: | 
					
						
							|  |  |  |             model_setting_map[model_setting.model_type][model_setting.model] = model_setting | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         if self.using_provider_type == ProviderType.SYSTEM: | 
					
						
							|  |  |  |             provider_models = self._get_system_provider_models( | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             ) | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             provider_models = self._get_custom_provider_models( | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 model_types=model_types, provider_instance=provider_instance, model_setting_map=model_setting_map | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if only_active: | 
					
						
							|  |  |  |             provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # resort provider_models | 
					
						
							|  |  |  |         return sorted(provider_models, key=lambda x: x.model_type.value) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |     def _get_system_provider_models( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         model_types: list[ModelType], | 
					
						
							|  |  |  |         provider_instance: ModelProvider, | 
					
						
							|  |  |  |         model_setting_map: dict[ModelType, dict[str, ModelSettings]], | 
					
						
							|  |  |  |     ) -> list[ModelWithProviderEntity]: | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Get system provider models. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :param model_types: model types | 
					
						
							|  |  |  |         :param provider_instance: provider instance | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |         :param model_setting_map: model setting map | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         provider_models = [] | 
					
						
							|  |  |  |         for model_type in model_types: | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |             for m in provider_instance.models(model_type): | 
					
						
							|  |  |  |                 status = ModelStatus.ACTIVE | 
					
						
							|  |  |  |                 if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]: | 
					
						
							|  |  |  |                     model_setting = model_setting_map[m.model_type][m.model] | 
					
						
							|  |  |  |                     if model_setting.enabled is False: | 
					
						
							|  |  |  |                         status = ModelStatus.DISABLED | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 provider_models.append( | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |                     ModelWithProviderEntity( | 
					
						
							| 
									
										
										
										
											2024-01-05 12:13:45 +08:00
										 |  |  |                         model=m.model, | 
					
						
							|  |  |  |                         label=m.label, | 
					
						
							|  |  |  |                         model_type=m.model_type, | 
					
						
							|  |  |  |                         features=m.features, | 
					
						
							|  |  |  |                         fetch_from=m.fetch_from, | 
					
						
							|  |  |  |                         model_properties=m.model_properties, | 
					
						
							|  |  |  |                         deprecated=m.deprecated, | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |                         provider=SimpleModelProviderEntity(self.provider), | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                         status=status, | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |                     ) | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |                 ) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-09 19:17:47 +08:00
										 |  |  |         if self.provider.provider not in original_provider_configurate_methods: | 
					
						
							|  |  |  |             original_provider_configurate_methods[self.provider.provider] = [] | 
					
						
							|  |  |  |             for configurate_method in provider_instance.get_provider_schema().configurate_methods: | 
					
						
							|  |  |  |                 original_provider_configurate_methods[self.provider.provider].append(configurate_method) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         should_use_custom_model = False | 
					
						
							|  |  |  |         if original_provider_configurate_methods[self.provider.provider] == [ConfigurateMethod.CUSTOMIZABLE_MODEL]: | 
					
						
							|  |  |  |             should_use_custom_model = True | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         for quota_configuration in self.system_configuration.quota_configurations: | 
					
						
							|  |  |  |             if self.system_configuration.current_quota_type != quota_configuration.quota_type: | 
					
						
							|  |  |  |                 continue | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-09 19:17:47 +08:00
										 |  |  |             restrict_models = quota_configuration.restrict_models | 
					
						
							|  |  |  |             if len(restrict_models) == 0: | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |                 break | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-09 19:17:47 +08:00
										 |  |  |             if should_use_custom_model: | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |                 if original_provider_configurate_methods[self.provider.provider] == [ | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                     ConfigurateMethod.CUSTOMIZABLE_MODEL | 
					
						
							|  |  |  |                 ]: | 
					
						
							| 
									
										
										
										
											2024-01-09 19:17:47 +08:00
										 |  |  |                     # only customizable model | 
					
						
							|  |  |  |                     for restrict_model in restrict_models: | 
					
						
							|  |  |  |                         copy_credentials = self.system_configuration.credentials.copy() | 
					
						
							|  |  |  |                         if restrict_model.base_model_name: | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                             copy_credentials["base_model_name"] = restrict_model.base_model_name | 
					
						
							| 
									
										
										
										
											2024-01-09 19:17:47 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |                         try: | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                             custom_model_schema = provider_instance.get_model_instance( | 
					
						
							|  |  |  |                                 restrict_model.model_type | 
					
						
							|  |  |  |                             ).get_customizable_model_schema_from_credentials(restrict_model.model, copy_credentials) | 
					
						
							| 
									
										
										
										
											2024-01-09 19:17:47 +08:00
										 |  |  |                         except Exception as ex: | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                             logger.warning(f"get custom model schema failed, {ex}") | 
					
						
							| 
									
										
										
										
											2024-01-09 19:17:47 +08:00
										 |  |  |                             continue | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                         if not custom_model_schema: | 
					
						
							|  |  |  |                             continue | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                         if custom_model_schema.model_type not in model_types: | 
					
						
							|  |  |  |                             continue | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |                         status = ModelStatus.ACTIVE | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                         if ( | 
					
						
							|  |  |  |                             custom_model_schema.model_type in model_setting_map | 
					
						
							|  |  |  |                             and custom_model_schema.model in model_setting_map[custom_model_schema.model_type] | 
					
						
							|  |  |  |                         ): | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |                             model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model] | 
					
						
							|  |  |  |                             if model_setting.enabled is False: | 
					
						
							|  |  |  |                                 status = ModelStatus.DISABLED | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-09 19:17:47 +08:00
										 |  |  |                         provider_models.append( | 
					
						
							|  |  |  |                             ModelWithProviderEntity( | 
					
						
							|  |  |  |                                 model=custom_model_schema.model, | 
					
						
							|  |  |  |                                 label=custom_model_schema.label, | 
					
						
							|  |  |  |                                 model_type=custom_model_schema.model_type, | 
					
						
							|  |  |  |                                 features=custom_model_schema.features, | 
					
						
							|  |  |  |                                 fetch_from=FetchFrom.PREDEFINED_MODEL, | 
					
						
							|  |  |  |                                 model_properties=custom_model_schema.model_properties, | 
					
						
							|  |  |  |                                 deprecated=custom_model_schema.deprecated, | 
					
						
							|  |  |  |                                 provider=SimpleModelProviderEntity(self.provider), | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                                 status=status, | 
					
						
							| 
									
										
										
										
											2024-01-09 19:17:47 +08:00
										 |  |  |                             ) | 
					
						
							|  |  |  |                         ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             # if llm name not in restricted llm list, remove it | 
					
						
							| 
									
										
										
										
											2024-01-09 19:17:47 +08:00
										 |  |  |             restrict_model_names = [rm.model for rm in restrict_models] | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             for m in provider_models: | 
					
						
							| 
									
										
										
										
											2024-01-09 19:17:47 +08:00
										 |  |  |                 if m.model_type == ModelType.LLM and m.model not in restrict_model_names: | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |                     m.status = ModelStatus.NO_PERMISSION | 
					
						
							|  |  |  |                 elif not quota_configuration.is_valid: | 
					
						
							|  |  |  |                     m.status = ModelStatus.QUOTA_EXCEEDED | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         return provider_models | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |     def _get_custom_provider_models( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         model_types: list[ModelType], | 
					
						
							|  |  |  |         provider_instance: ModelProvider, | 
					
						
							|  |  |  |         model_setting_map: dict[ModelType, dict[str, ModelSettings]], | 
					
						
							|  |  |  |     ) -> list[ModelWithProviderEntity]: | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Get custom provider models. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :param model_types: model types | 
					
						
							|  |  |  |         :param provider_instance: provider instance | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |         :param model_setting_map: model setting map | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         provider_models = [] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         credentials = None | 
					
						
							|  |  |  |         if self.custom_configuration.provider: | 
					
						
							|  |  |  |             credentials = self.custom_configuration.provider.credentials | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         for model_type in model_types: | 
					
						
							|  |  |  |             if model_type not in self.provider.supported_model_types: | 
					
						
							|  |  |  |                 continue | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             models = provider_instance.models(model_type) | 
					
						
							|  |  |  |             for m in models: | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |                 status = ModelStatus.ACTIVE if credentials else ModelStatus.NO_CONFIGURE | 
					
						
							|  |  |  |                 load_balancing_enabled = False | 
					
						
							|  |  |  |                 if m.model_type in model_setting_map and m.model in model_setting_map[m.model_type]: | 
					
						
							|  |  |  |                     model_setting = model_setting_map[m.model_type][m.model] | 
					
						
							|  |  |  |                     if model_setting.enabled is False: | 
					
						
							|  |  |  |                         status = ModelStatus.DISABLED | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     if len(model_setting.load_balancing_configs) > 1: | 
					
						
							|  |  |  |                         load_balancing_enabled = True | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |                 provider_models.append( | 
					
						
							|  |  |  |                     ModelWithProviderEntity( | 
					
						
							| 
									
										
										
										
											2024-01-05 12:13:45 +08:00
										 |  |  |                         model=m.model, | 
					
						
							|  |  |  |                         label=m.label, | 
					
						
							|  |  |  |                         model_type=m.model_type, | 
					
						
							|  |  |  |                         features=m.features, | 
					
						
							|  |  |  |                         fetch_from=m.fetch_from, | 
					
						
							|  |  |  |                         model_properties=m.model_properties, | 
					
						
							|  |  |  |                         deprecated=m.deprecated, | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |                         provider=SimpleModelProviderEntity(self.provider), | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |                         status=status, | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                         load_balancing_enabled=load_balancing_enabled, | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |                     ) | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # custom models | 
					
						
							|  |  |  |         for model_configuration in self.custom_configuration.models: | 
					
						
							|  |  |  |             if model_configuration.model_type not in model_types: | 
					
						
							|  |  |  |                 continue | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-04 10:53:50 +08:00
										 |  |  |             try: | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 custom_model_schema = provider_instance.get_model_instance( | 
					
						
							|  |  |  |                     model_configuration.model_type | 
					
						
							|  |  |  |                 ).get_customizable_model_schema_from_credentials( | 
					
						
							|  |  |  |                     model_configuration.model, model_configuration.credentials | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |                 ) | 
					
						
							| 
									
										
										
										
											2024-01-04 10:53:50 +08:00
										 |  |  |             except Exception as ex: | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 logger.warning(f"get custom model schema failed, {ex}") | 
					
						
							| 
									
										
										
										
											2024-01-04 10:53:50 +08:00
										 |  |  |                 continue | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             if not custom_model_schema: | 
					
						
							|  |  |  |                 continue | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |             status = ModelStatus.ACTIVE | 
					
						
							|  |  |  |             load_balancing_enabled = False | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |             if ( | 
					
						
							|  |  |  |                 custom_model_schema.model_type in model_setting_map | 
					
						
							|  |  |  |                 and custom_model_schema.model in model_setting_map[custom_model_schema.model_type] | 
					
						
							|  |  |  |             ): | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |                 model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model] | 
					
						
							|  |  |  |                 if model_setting.enabled is False: | 
					
						
							|  |  |  |                     status = ModelStatus.DISABLED | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 if len(model_setting.load_balancing_configs) > 1: | 
					
						
							|  |  |  |                     load_balancing_enabled = True | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             provider_models.append( | 
					
						
							|  |  |  |                 ModelWithProviderEntity( | 
					
						
							| 
									
										
										
										
											2024-01-05 12:13:45 +08:00
										 |  |  |                     model=custom_model_schema.model, | 
					
						
							|  |  |  |                     label=custom_model_schema.label, | 
					
						
							|  |  |  |                     model_type=custom_model_schema.model_type, | 
					
						
							|  |  |  |                     features=custom_model_schema.features, | 
					
						
							|  |  |  |                     fetch_from=custom_model_schema.fetch_from, | 
					
						
							|  |  |  |                     model_properties=custom_model_schema.model_properties, | 
					
						
							|  |  |  |                     deprecated=custom_model_schema.deprecated, | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |                     provider=SimpleModelProviderEntity(self.provider), | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |                     status=status, | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                     load_balancing_enabled=load_balancing_enabled, | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |                 ) | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return provider_models | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class ProviderConfigurations(BaseModel): | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     Model class for provider configuration dict. | 
					
						
							|  |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |     tenant_id: str | 
					
						
							| 
									
										
										
										
											2024-02-09 15:21:33 +08:00
										 |  |  |     configurations: dict[str, ProviderConfiguration] = {} | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def __init__(self, tenant_id: str): | 
					
						
							|  |  |  |         super().__init__(tenant_id=tenant_id) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |     def get_models( | 
					
						
							|  |  |  |         self, provider: Optional[str] = None, model_type: Optional[ModelType] = None, only_active: bool = False | 
					
						
							|  |  |  |     ) -> list[ModelWithProviderEntity]: | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Get available models. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         If preferred provider type is `system`: | 
					
						
							|  |  |  |           Get the current **system mode** if provider supported, | 
					
						
							|  |  |  |           if all system modes are not available (no quota), it is considered to be the **custom credential mode**. | 
					
						
							|  |  |  |           If there is no model configured in custom mode, it is treated as no_configure. | 
					
						
							|  |  |  |         system > custom > no_configure | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         If preferred provider type is `custom`: | 
					
						
							|  |  |  |           If custom credentials are configured, it is treated as custom mode. | 
					
						
							|  |  |  |           Otherwise, get the current **system mode** if supported, | 
					
						
							|  |  |  |           If all system modes are not available (no quota), it is treated as no_configure. | 
					
						
							|  |  |  |         custom > system > no_configure | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         If real mode is `system`, use system credentials to get models, | 
					
						
							|  |  |  |           paid quotas > provider free quotas > system free quotas | 
					
						
							|  |  |  |           include pre-defined models (exclude GPT-4, status marked as `no_permission`). | 
					
						
							|  |  |  |         If real mode is `custom`, use workspace custom credentials to get models, | 
					
						
							|  |  |  |           include pre-defined models, custom models(manual append). | 
					
						
							|  |  |  |         If real mode is `no_configure`, only return pre-defined models from `model runtime`. | 
					
						
							|  |  |  |           (model status marked as `no_configure` if preferred provider type is `custom` otherwise `quota_exceeded`) | 
					
						
							|  |  |  |         model status marked as `active` is available. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :param provider: provider name | 
					
						
							|  |  |  |         :param model_type: model type | 
					
						
							|  |  |  |         :param only_active: only active models | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         all_models = [] | 
					
						
							|  |  |  |         for provider_configuration in self.values(): | 
					
						
							|  |  |  |             if provider and provider_configuration.provider.provider != provider: | 
					
						
							|  |  |  |                 continue | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             all_models.extend(provider_configuration.get_provider_models(model_type, only_active)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return all_models | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-09 15:21:33 +08:00
										 |  |  |     def to_list(self) -> list[ProviderConfiguration]: | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Convert to list. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         return list(self.values()) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __getitem__(self, key): | 
					
						
							|  |  |  |         return self.configurations[key] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __setitem__(self, key, value): | 
					
						
							|  |  |  |         self.configurations[key] = value | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __iter__(self): | 
					
						
							|  |  |  |         return iter(self.configurations) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def values(self) -> Iterator[ProviderConfiguration]: | 
					
						
							|  |  |  |         return self.configurations.values() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get(self, key, default=None): | 
					
						
							|  |  |  |         return self.configurations.get(key, default) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class ProviderModelBundle(BaseModel): | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     Provider model bundle. | 
					
						
							|  |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |     configuration: ProviderConfiguration | 
					
						
							|  |  |  |     provider_instance: ModelProvider | 
					
						
							|  |  |  |     model_type_instance: AIModel | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-14 01:05:37 +08:00
										 |  |  |     # pydantic configs | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |     model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=()) |