| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | import json | 
					
						
							|  |  |  | from collections import defaultdict | 
					
						
							|  |  |  | from json import JSONDecodeError | 
					
						
							|  |  |  | from typing import Optional | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-02-06 13:21:13 +08:00
										 |  |  | from sqlalchemy.exc import IntegrityError | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-20 23:16:43 -04:00
										 |  |  | from configs import dify_config | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity | 
					
						
							| 
									
										
										
										
											2024-01-12 12:34:01 +08:00
										 |  |  | from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle | 
					
						
							| 
									
										
										
										
											2024-02-06 13:21:13 +08:00
										 |  |  | from core.entities.provider_entities import ( | 
					
						
							|  |  |  |     CustomConfiguration, | 
					
						
							|  |  |  |     CustomModelConfiguration, | 
					
						
							|  |  |  |     CustomProviderConfiguration, | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |     ModelLoadBalancingConfiguration, | 
					
						
							|  |  |  |     ModelSettings, | 
					
						
							| 
									
										
										
										
											2024-02-06 13:21:13 +08:00
										 |  |  |     QuotaConfiguration, | 
					
						
							|  |  |  |     SystemConfiguration, | 
					
						
							|  |  |  | ) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | from core.helper import encrypter | 
					
						
							| 
									
										
										
										
											2024-01-04 20:48:54 +08:00
										 |  |  | from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType | 
					
						
							| 
									
										
										
										
											2024-08-20 23:16:43 -04:00
										 |  |  | from core.helper.position_helper import is_filtered | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | from core.model_runtime.entities.model_entities import ModelType | 
					
						
							| 
									
										
										
										
											2024-08-20 23:16:43 -04:00
										 |  |  | from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType, ProviderEntity | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | from core.model_runtime.model_providers import model_provider_factory | 
					
						
							|  |  |  | from extensions import ext_hosting_provider | 
					
						
							|  |  |  | from extensions.ext_database import db | 
					
						
							| 
									
										
										
										
											2024-06-05 02:06:19 +08:00
										 |  |  | from extensions.ext_redis import redis_client | 
					
						
							| 
									
										
										
										
											2024-02-06 13:21:13 +08:00
										 |  |  | from models.provider import ( | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |     LoadBalancingModelConfig, | 
					
						
							| 
									
										
										
										
											2024-02-06 13:21:13 +08:00
										 |  |  |     Provider, | 
					
						
							|  |  |  |     ProviderModel, | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |     ProviderModelSetting, | 
					
						
							| 
									
										
										
										
											2024-02-06 13:21:13 +08:00
										 |  |  |     ProviderQuotaType, | 
					
						
							|  |  |  |     ProviderType, | 
					
						
							|  |  |  |     TenantDefaultModel, | 
					
						
							|  |  |  |     TenantPreferredModelProvider, | 
					
						
							|  |  |  | ) | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  | from services.feature_service import FeatureService | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class ProviderManager: | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers. | 
					
						
							|  |  |  |     """
 | 
					
						
							| 
									
										
										
										
											2024-08-20 23:16:43 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-05 12:13:45 +08:00
										 |  |  |     def __init__(self) -> None: | 
					
						
							|  |  |  |         self.decoding_rsa_key = None | 
					
						
							|  |  |  |         self.decoding_cipher_rsa = None | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def get_configurations(self, tenant_id: str) -> ProviderConfigurations: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Get model provider configurations. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Construct ProviderConfiguration objects for each provider | 
					
						
							|  |  |  |         Including: | 
					
						
							|  |  |  |         1. Basic information of the provider | 
					
						
							|  |  |  |         2. Hosting configuration information, including: | 
					
						
							|  |  |  |           (1. Whether to enable (support) hosting type, if enabled, the following information exists | 
					
						
							|  |  |  |           (2. List of hosting type provider configurations | 
					
						
							|  |  |  |               (including quota type, quota limit, current remaining quota, etc.) | 
					
						
							|  |  |  |           (3. The current hosting type in use (whether there is a quota or not) | 
					
						
							|  |  |  |               paid quotas > provider free quotas > hosting trial quotas | 
					
						
							|  |  |  |           (4. Unified credentials for hosting providers | 
					
						
							|  |  |  |         3. Custom configuration information, including: | 
					
						
							|  |  |  |           (1. Whether to enable (support) custom type, if enabled, the following information exists | 
					
						
							|  |  |  |           (2. Custom provider configuration (including credentials) | 
					
						
							|  |  |  |           (3. List of custom provider model configurations (including credentials) | 
					
						
							|  |  |  |         4. Hosting/custom preferred provider type. | 
					
						
							|  |  |  |         Provide methods: | 
					
						
							|  |  |  |         - Get the current configuration (including credentials) | 
					
						
							|  |  |  |         - Get the availability and status of the hosting configuration: active available, | 
					
						
							|  |  |  |           quota_exceeded insufficient quota, unsupported hosting | 
					
						
							|  |  |  |         - Get the availability of custom configuration | 
					
						
							|  |  |  |           Custom provider available conditions: | 
					
						
							|  |  |  |           (1. custom provider credentials available | 
					
						
							|  |  |  |           (2. at least one custom model credentials available | 
					
						
							|  |  |  |         - Verify, update, and delete custom provider configuration | 
					
						
							|  |  |  |         - Verify, update, and delete custom provider model configuration | 
					
						
							|  |  |  |         - Get the list of available models (optional provider filtering, model type filtering) | 
					
						
							|  |  |  |           Append custom provider models to the list | 
					
						
							|  |  |  |         - Get provider instance | 
					
						
							|  |  |  |         - Switch selection priority | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :param tenant_id: | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         # Get all provider records of the workspace | 
					
						
							|  |  |  |         provider_name_to_provider_records_dict = self._get_all_providers(tenant_id) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Initialize trial provider records if not exist | 
					
						
							|  |  |  |         provider_name_to_provider_records_dict = self._init_trial_provider_records( | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |             tenant_id, provider_name_to_provider_records_dict | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Get all provider model records of the workspace | 
					
						
							|  |  |  |         provider_name_to_provider_model_records_dict = self._get_all_provider_models(tenant_id) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Get all provider entities | 
					
						
							|  |  |  |         provider_entities = model_provider_factory.get_providers() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Get All preferred provider types of the workspace | 
					
						
							|  |  |  |         provider_name_to_preferred_model_provider_records_dict = self._get_all_preferred_model_providers(tenant_id) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |         # Get All provider model settings | 
					
						
							|  |  |  |         provider_name_to_provider_model_settings_dict = self._get_all_provider_model_settings(tenant_id) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Get All load balancing configs | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         provider_name_to_provider_load_balancing_model_configs_dict = self._get_all_provider_load_balancing_configs( | 
					
						
							|  |  |  |             tenant_id | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         provider_configurations = ProviderConfigurations(tenant_id=tenant_id) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         # Construct ProviderConfiguration objects for each provider | 
					
						
							|  |  |  |         for provider_entity in provider_entities: | 
					
						
							| 
									
										
										
										
											2024-08-20 23:16:43 -04:00
										 |  |  |             # handle include, exclude | 
					
						
							|  |  |  |             if is_filtered( | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET, | 
					
						
							|  |  |  |                 exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET, | 
					
						
							|  |  |  |                 data=provider_entity, | 
					
						
							|  |  |  |                 name_func=lambda x: x.provider, | 
					
						
							| 
									
										
										
										
											2024-08-20 23:16:43 -04:00
										 |  |  |             ): | 
					
						
							|  |  |  |                 continue | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             provider_name = provider_entity.provider | 
					
						
							| 
									
										
										
										
											2024-05-22 11:18:03 +08:00
										 |  |  |             provider_records = provider_name_to_provider_records_dict.get(provider_entity.provider, []) | 
					
						
							|  |  |  |             provider_model_records = provider_name_to_provider_model_records_dict.get(provider_entity.provider, []) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             # Convert to custom configuration | 
					
						
							|  |  |  |             custom_configuration = self._to_custom_configuration( | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 tenant_id, provider_entity, provider_records, provider_model_records | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             # Convert to system configuration | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |             system_configuration = self._to_system_configuration(tenant_id, provider_entity, provider_records) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             # Get preferred provider type | 
					
						
							|  |  |  |             preferred_provider_type_record = provider_name_to_preferred_model_provider_records_dict.get(provider_name) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if preferred_provider_type_record: | 
					
						
							|  |  |  |                 preferred_provider_type = ProviderType.value_of(preferred_provider_type_record.preferred_provider_type) | 
					
						
							| 
									
										
										
										
											2024-05-22 11:18:03 +08:00
										 |  |  |             elif custom_configuration.provider or custom_configuration.models: | 
					
						
							|  |  |  |                 preferred_provider_type = ProviderType.CUSTOM | 
					
						
							|  |  |  |             elif system_configuration.enabled: | 
					
						
							|  |  |  |                 preferred_provider_type = ProviderType.SYSTEM | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             else: | 
					
						
							| 
									
										
										
										
											2024-05-22 11:18:03 +08:00
										 |  |  |                 preferred_provider_type = ProviderType.CUSTOM | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             using_provider_type = preferred_provider_type | 
					
						
							| 
									
										
										
										
											2024-05-22 11:18:03 +08:00
										 |  |  |             has_valid_quota = any(quota_conf.is_valid for quota_conf in system_configuration.quota_configurations) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             if preferred_provider_type == ProviderType.SYSTEM: | 
					
						
							| 
									
										
										
										
											2024-05-22 11:18:03 +08:00
										 |  |  |                 if not system_configuration.enabled or not has_valid_quota: | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |                     using_provider_type = ProviderType.CUSTOM | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 if not custom_configuration.provider and not custom_configuration.models: | 
					
						
							| 
									
										
										
										
											2024-05-22 11:18:03 +08:00
										 |  |  |                     if system_configuration.enabled and has_valid_quota: | 
					
						
							|  |  |  |                         using_provider_type = ProviderType.SYSTEM | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |             # Get provider load balancing configs | 
					
						
							|  |  |  |             provider_model_settings = provider_name_to_provider_model_settings_dict.get(provider_name) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             # Get provider load balancing configs | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |             provider_load_balancing_configs = provider_name_to_provider_load_balancing_model_configs_dict.get( | 
					
						
							|  |  |  |                 provider_name | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             # Convert to model settings | 
					
						
							|  |  |  |             model_settings = self._to_model_settings( | 
					
						
							|  |  |  |                 provider_entity=provider_entity, | 
					
						
							|  |  |  |                 provider_model_settings=provider_model_settings, | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 load_balancing_model_configs=provider_load_balancing_configs, | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             provider_configuration = ProviderConfiguration( | 
					
						
							|  |  |  |                 tenant_id=tenant_id, | 
					
						
							|  |  |  |                 provider=provider_entity, | 
					
						
							|  |  |  |                 preferred_provider_type=preferred_provider_type, | 
					
						
							|  |  |  |                 using_provider_type=using_provider_type, | 
					
						
							|  |  |  |                 system_configuration=system_configuration, | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |                 custom_configuration=custom_configuration, | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 model_settings=model_settings, | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             provider_configurations[provider_name] = provider_configuration | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Return the encapsulated object | 
					
						
							|  |  |  |         return provider_configurations | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_provider_model_bundle(self, tenant_id: str, provider: str, model_type: ModelType) -> ProviderModelBundle: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Get provider model bundle. | 
					
						
							|  |  |  |         :param tenant_id: workspace id | 
					
						
							|  |  |  |         :param provider: provider name | 
					
						
							|  |  |  |         :param model_type: model type | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         provider_configurations = self.get_configurations(tenant_id) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # get provider instance | 
					
						
							|  |  |  |         provider_configuration = provider_configurations.get(provider) | 
					
						
							|  |  |  |         if not provider_configuration: | 
					
						
							|  |  |  |             raise ValueError(f"Provider {provider} does not exist.") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         provider_instance = provider_configuration.get_provider_instance() | 
					
						
							|  |  |  |         model_type_instance = provider_instance.get_model_instance(model_type) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return ProviderModelBundle( | 
					
						
							|  |  |  |             configuration=provider_configuration, | 
					
						
							|  |  |  |             provider_instance=provider_instance, | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |             model_type_instance=model_type_instance, | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_default_model(self, tenant_id: str, model_type: ModelType) -> Optional[DefaultModelEntity]: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Get default model. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :param tenant_id: workspace id | 
					
						
							|  |  |  |         :param model_type: model type | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         # Get the corresponding TenantDefaultModel record | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         default_model = ( | 
					
						
							|  |  |  |             db.session.query(TenantDefaultModel) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             .filter( | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 TenantDefaultModel.tenant_id == tenant_id, | 
					
						
							|  |  |  |                 TenantDefaultModel.model_type == model_type.to_origin_model_type(), | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             .first() | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         # If it does not exist, get the first available provider model from get_configurations | 
					
						
							|  |  |  |         # and update the TenantDefaultModel record | 
					
						
							|  |  |  |         if not default_model: | 
					
						
							|  |  |  |             # Get provider configurations | 
					
						
							|  |  |  |             provider_configurations = self.get_configurations(tenant_id) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             # get available models from provider_configurations | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |             available_models = provider_configurations.get_models(model_type=model_type, only_active=True) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             if available_models: | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 available_model = next( | 
					
						
							|  |  |  |                     (model for model in available_models if model.model == "gpt-4"), available_models[0] | 
					
						
							|  |  |  |                 ) | 
					
						
							| 
									
										
										
										
											2024-05-22 11:18:03 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |                 default_model = TenantDefaultModel( | 
					
						
							|  |  |  |                     tenant_id=tenant_id, | 
					
						
							|  |  |  |                     model_type=model_type.to_origin_model_type(), | 
					
						
							|  |  |  |                     provider_name=available_model.provider.provider, | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                     model_name=available_model.model, | 
					
						
							| 
									
										
										
										
											2024-05-22 11:18:03 +08:00
										 |  |  |                 ) | 
					
						
							|  |  |  |                 db.session.add(default_model) | 
					
						
							|  |  |  |                 db.session.commit() | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         if not default_model: | 
					
						
							|  |  |  |             return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         provider_instance = model_provider_factory.get_provider_instance(default_model.provider_name) | 
					
						
							| 
									
										
										
										
											2024-01-05 09:43:41 +08:00
										 |  |  |         provider_schema = provider_instance.get_provider_schema() | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         return DefaultModelEntity( | 
					
						
							|  |  |  |             model=default_model.model_name, | 
					
						
							|  |  |  |             model_type=model_type, | 
					
						
							| 
									
										
										
										
											2024-01-05 09:43:41 +08:00
										 |  |  |             provider=DefaultModelProviderEntity( | 
					
						
							|  |  |  |                 provider=provider_schema.provider, | 
					
						
							|  |  |  |                 label=provider_schema.label, | 
					
						
							|  |  |  |                 icon_small=provider_schema.icon_small, | 
					
						
							|  |  |  |                 icon_large=provider_schema.icon_large, | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 supported_model_types=provider_schema.supported_model_types, | 
					
						
							|  |  |  |             ), | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-20 23:16:43 -04:00
										 |  |  |     def get_first_provider_first_model(self, tenant_id: str, model_type: ModelType) -> tuple[str, str]: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Get names of first model and its provider | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :param tenant_id: workspace id | 
					
						
							|  |  |  |         :param model_type: model type | 
					
						
							|  |  |  |         :return: provider name, model name | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         provider_configurations = self.get_configurations(tenant_id) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # get available models from provider_configurations | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         all_models = provider_configurations.get_models(model_type=model_type, only_active=False) | 
					
						
							| 
									
										
										
										
											2024-08-20 23:16:43 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  |         return all_models[0].provider.provider, all_models[0].model | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |     def update_default_model_record( | 
					
						
							|  |  |  |         self, tenant_id: str, model_type: ModelType, provider: str, model: str | 
					
						
							|  |  |  |     ) -> TenantDefaultModel: | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Update default model record. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :param tenant_id: workspace id | 
					
						
							|  |  |  |         :param model_type: model type | 
					
						
							|  |  |  |         :param provider: provider name | 
					
						
							|  |  |  |         :param model: model name | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         provider_configurations = self.get_configurations(tenant_id) | 
					
						
							|  |  |  |         if provider not in provider_configurations: | 
					
						
							|  |  |  |             raise ValueError(f"Provider {provider} does not exist.") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # get available models from provider_configurations | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         available_models = provider_configurations.get_models(model_type=model_type, only_active=True) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         # check if the model is exist in available models | 
					
						
							|  |  |  |         model_names = [model.model for model in available_models] | 
					
						
							|  |  |  |         if model not in model_names: | 
					
						
							|  |  |  |             raise ValueError(f"Model {model} does not exist.") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Get the list of available models from get_configurations and check if it is LLM | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         default_model = ( | 
					
						
							|  |  |  |             db.session.query(TenantDefaultModel) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             .filter( | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 TenantDefaultModel.tenant_id == tenant_id, | 
					
						
							|  |  |  |                 TenantDefaultModel.model_type == model_type.to_origin_model_type(), | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             .first() | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         # create or update TenantDefaultModel record | 
					
						
							|  |  |  |         if default_model: | 
					
						
							|  |  |  |             # update default model | 
					
						
							|  |  |  |             default_model.provider_name = provider | 
					
						
							|  |  |  |             default_model.model_name = model | 
					
						
							|  |  |  |             db.session.commit() | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             # create default model | 
					
						
							|  |  |  |             default_model = TenantDefaultModel( | 
					
						
							|  |  |  |                 tenant_id=tenant_id, | 
					
						
							|  |  |  |                 model_type=model_type.value, | 
					
						
							|  |  |  |                 provider_name=provider, | 
					
						
							|  |  |  |                 model_name=model, | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             db.session.add(default_model) | 
					
						
							|  |  |  |             db.session.commit() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return default_model | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-26 20:45:03 +09:00
										 |  |  |     @staticmethod | 
					
						
							|  |  |  |     def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]: | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Get all provider records of the workspace. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :param tenant_id: workspace id | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid == True).all() | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         provider_name_to_provider_records_dict = defaultdict(list) | 
					
						
							|  |  |  |         for provider in providers: | 
					
						
							|  |  |  |             provider_name_to_provider_records_dict[provider.provider_name].append(provider) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return provider_name_to_provider_records_dict | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-26 20:45:03 +09:00
										 |  |  |     @staticmethod | 
					
						
							|  |  |  |     def _get_all_provider_models(tenant_id: str) -> dict[str, list[ProviderModel]]: | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Get all provider model records of the workspace. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :param tenant_id: workspace id | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         # Get all provider model records of the workspace | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         provider_models = ( | 
					
						
							|  |  |  |             db.session.query(ProviderModel) | 
					
						
							|  |  |  |             .filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True) | 
					
						
							|  |  |  |             .all() | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         provider_name_to_provider_model_records_dict = defaultdict(list) | 
					
						
							|  |  |  |         for provider_model in provider_models: | 
					
						
							|  |  |  |             provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return provider_name_to_provider_model_records_dict | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-26 20:45:03 +09:00
										 |  |  |     @staticmethod | 
					
						
							|  |  |  |     def _get_all_preferred_model_providers(tenant_id: str) -> dict[str, TenantPreferredModelProvider]: | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Get All preferred provider types of the workspace. | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |         :param tenant_id: workspace id | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         preferred_provider_types = ( | 
					
						
							|  |  |  |             db.session.query(TenantPreferredModelProvider) | 
					
						
							|  |  |  |             .filter(TenantPreferredModelProvider.tenant_id == tenant_id) | 
					
						
							|  |  |  |             .all() | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         provider_name_to_preferred_provider_type_records_dict = { | 
					
						
							|  |  |  |             preferred_provider_type.provider_name: preferred_provider_type | 
					
						
							|  |  |  |             for preferred_provider_type in preferred_provider_types | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return provider_name_to_preferred_provider_type_records_dict | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-26 20:45:03 +09:00
										 |  |  |     @staticmethod | 
					
						
							|  |  |  |     def _get_all_provider_model_settings(tenant_id: str) -> dict[str, list[ProviderModelSetting]]: | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Get All provider model settings of the workspace. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :param tenant_id: workspace id | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         provider_model_settings = ( | 
					
						
							|  |  |  |             db.session.query(ProviderModelSetting).filter(ProviderModelSetting.tenant_id == tenant_id).all() | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         provider_name_to_provider_model_settings_dict = defaultdict(list) | 
					
						
							|  |  |  |         for provider_model_setting in provider_model_settings: | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |             ( | 
					
						
							|  |  |  |                 provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name].append( | 
					
						
							|  |  |  |                     provider_model_setting | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         return provider_name_to_provider_model_settings_dict | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-26 20:45:03 +09:00
										 |  |  |     @staticmethod | 
					
						
							|  |  |  |     def _get_all_provider_load_balancing_configs(tenant_id: str) -> dict[str, list[LoadBalancingModelConfig]]: | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Get All provider load balancing configs of the workspace. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :param tenant_id: workspace id | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2024-06-05 02:06:19 +08:00
										 |  |  |         cache_key = f"tenant:{tenant_id}:model_load_balancing_enabled" | 
					
						
							|  |  |  |         cache_result = redis_client.get(cache_key) | 
					
						
							|  |  |  |         if cache_result is None: | 
					
						
							|  |  |  |             model_load_balancing_enabled = FeatureService.get_features(tenant_id).model_load_balancing_enabled | 
					
						
							|  |  |  |             redis_client.setex(cache_key, 120, str(model_load_balancing_enabled)) | 
					
						
							|  |  |  |         else: | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |             cache_result = cache_result.decode("utf-8") | 
					
						
							|  |  |  |             model_load_balancing_enabled = cache_result == "True" | 
					
						
							| 
									
										
										
										
											2024-06-05 02:06:19 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  |         if not model_load_balancing_enabled: | 
					
						
							| 
									
										
										
										
											2024-06-27 11:21:31 +08:00
										 |  |  |             return {} | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         provider_load_balancing_configs = ( | 
					
						
							|  |  |  |             db.session.query(LoadBalancingModelConfig).filter(LoadBalancingModelConfig.tenant_id == tenant_id).all() | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list) | 
					
						
							|  |  |  |         for provider_load_balancing_config in provider_load_balancing_configs: | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |             ( | 
					
						
							|  |  |  |                 provider_name_to_provider_load_balancing_model_configs_dict[ | 
					
						
							|  |  |  |                     provider_load_balancing_config.provider_name | 
					
						
							|  |  |  |                 ].append(provider_load_balancing_config) | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         return provider_name_to_provider_load_balancing_model_configs_dict | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-26 20:45:03 +09:00
										 |  |  |     @staticmethod | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |     def _init_trial_provider_records( | 
					
						
							|  |  |  |         tenant_id: str, provider_name_to_provider_records_dict: dict[str, list] | 
					
						
							|  |  |  |     ) -> dict[str, list]: | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Initialize trial provider records if not exists. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :param tenant_id: workspace id | 
					
						
							|  |  |  |         :param provider_name_to_provider_records_dict: provider name to provider records dict | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         # Get hosting configuration | 
					
						
							|  |  |  |         hosting_configuration = ext_hosting_provider.hosting_configuration | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         for provider_name, configuration in hosting_configuration.provider_map.items(): | 
					
						
							|  |  |  |             if not configuration.enabled: | 
					
						
							|  |  |  |                 continue | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             provider_records = provider_name_to_provider_records_dict.get(provider_name) | 
					
						
							|  |  |  |             if not provider_records: | 
					
						
							|  |  |  |                 provider_records = [] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-27 11:21:31 +08:00
										 |  |  |             provider_quota_to_provider_record_dict = {} | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             for provider_record in provider_records: | 
					
						
							|  |  |  |                 if provider_record.provider_type != ProviderType.SYSTEM.value: | 
					
						
							|  |  |  |                     continue | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 provider_quota_to_provider_record_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = ( | 
					
						
							|  |  |  |                     provider_record | 
					
						
							|  |  |  |                 ) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             for quota in configuration.quotas: | 
					
						
							|  |  |  |                 if quota.quota_type == ProviderQuotaType.TRIAL: | 
					
						
							|  |  |  |                     # Init trial provider records if not exists | 
					
						
							|  |  |  |                     if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict: | 
					
						
							| 
									
										
										
										
											2024-01-03 09:12:53 +08:00
										 |  |  |                         try: | 
					
						
							|  |  |  |                             provider_record = Provider( | 
					
						
							|  |  |  |                                 tenant_id=tenant_id, | 
					
						
							|  |  |  |                                 provider_name=provider_name, | 
					
						
							|  |  |  |                                 provider_type=ProviderType.SYSTEM.value, | 
					
						
							|  |  |  |                                 quota_type=ProviderQuotaType.TRIAL.value, | 
					
						
							|  |  |  |                                 quota_limit=quota.quota_limit, | 
					
						
							|  |  |  |                                 quota_used=0, | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                                 is_valid=True, | 
					
						
							| 
									
										
										
										
											2024-01-03 09:12:53 +08:00
										 |  |  |                             ) | 
					
						
							|  |  |  |                             db.session.add(provider_record) | 
					
						
							|  |  |  |                             db.session.commit() | 
					
						
							|  |  |  |                         except IntegrityError: | 
					
						
							|  |  |  |                             db.session.rollback() | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                             provider_record = ( | 
					
						
							|  |  |  |                                 db.session.query(Provider) | 
					
						
							| 
									
										
										
										
											2024-01-03 09:12:53 +08:00
										 |  |  |                                 .filter( | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                                     Provider.tenant_id == tenant_id, | 
					
						
							|  |  |  |                                     Provider.provider_name == provider_name, | 
					
						
							|  |  |  |                                     Provider.provider_type == ProviderType.SYSTEM.value, | 
					
						
							|  |  |  |                                     Provider.quota_type == ProviderQuotaType.TRIAL.value, | 
					
						
							|  |  |  |                                 ) | 
					
						
							|  |  |  |                                 .first() | 
					
						
							|  |  |  |                             ) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-04 17:33:26 +08:00
										 |  |  |                             if provider_record and not provider_record.is_valid: | 
					
						
							|  |  |  |                                 provider_record.is_valid = True | 
					
						
							|  |  |  |                                 db.session.commit() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |                         provider_name_to_provider_records_dict[provider_name].append(provider_record) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return provider_name_to_provider_records_dict | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |     def _to_custom_configuration( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         tenant_id: str, | 
					
						
							|  |  |  |         provider_entity: ProviderEntity, | 
					
						
							|  |  |  |         provider_records: list[Provider], | 
					
						
							|  |  |  |         provider_model_records: list[ProviderModel], | 
					
						
							|  |  |  |     ) -> CustomConfiguration: | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Convert to custom configuration. | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-04 20:48:54 +08:00
										 |  |  |         :param tenant_id: workspace id | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         :param provider_entity: provider entity | 
					
						
							|  |  |  |         :param provider_records: provider records | 
					
						
							|  |  |  |         :param provider_model_records: provider model records | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         # Get provider credential secret variables | 
					
						
							|  |  |  |         provider_credential_secret_variables = self._extract_secret_variables( | 
					
						
							|  |  |  |             provider_entity.provider_credential_schema.credential_form_schemas | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |             if provider_entity.provider_credential_schema | 
					
						
							|  |  |  |             else [] | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Get custom provider record | 
					
						
							|  |  |  |         custom_provider_record = None | 
					
						
							|  |  |  |         for provider_record in provider_records: | 
					
						
							|  |  |  |             if provider_record.provider_type == ProviderType.SYSTEM.value: | 
					
						
							|  |  |  |                 continue | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if not provider_record.encrypted_config: | 
					
						
							|  |  |  |                 continue | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             custom_provider_record = provider_record | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Get custom provider credentials | 
					
						
							|  |  |  |         custom_provider_configuration = None | 
					
						
							|  |  |  |         if custom_provider_record: | 
					
						
							| 
									
										
										
										
											2024-01-04 20:48:54 +08:00
										 |  |  |             provider_credentials_cache = ProviderCredentialsCache( | 
					
						
							|  |  |  |                 tenant_id=tenant_id, | 
					
						
							|  |  |  |                 identity_id=custom_provider_record.id, | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 cache_type=ProviderCredentialsCacheType.PROVIDER, | 
					
						
							| 
									
										
										
										
											2024-01-04 20:48:54 +08:00
										 |  |  |             ) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-04 20:48:54 +08:00
										 |  |  |             # Get cached provider credentials | 
					
						
							|  |  |  |             cached_provider_credentials = provider_credentials_cache.get() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if not cached_provider_credentials: | 
					
						
							|  |  |  |                 try: | 
					
						
							|  |  |  |                     # fix origin data | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                     if ( | 
					
						
							|  |  |  |                         custom_provider_record.encrypted_config | 
					
						
							|  |  |  |                         and not custom_provider_record.encrypted_config.startswith("{") | 
					
						
							|  |  |  |                     ): | 
					
						
							|  |  |  |                         provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config} | 
					
						
							| 
									
										
										
										
											2024-01-04 20:48:54 +08:00
										 |  |  |                     else: | 
					
						
							|  |  |  |                         provider_credentials = json.loads(custom_provider_record.encrypted_config) | 
					
						
							|  |  |  |                 except JSONDecodeError: | 
					
						
							|  |  |  |                     provider_credentials = {} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 # Get decoding rsa key and cipher for decrypting credentials | 
					
						
							| 
									
										
										
										
											2024-01-05 12:13:45 +08:00
										 |  |  |                 if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None: | 
					
						
							|  |  |  |                     self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) | 
					
						
							| 
									
										
										
										
											2024-01-04 20:48:54 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |                 for variable in provider_credential_secret_variables: | 
					
						
							|  |  |  |                     if variable in provider_credentials: | 
					
						
							|  |  |  |                         try: | 
					
						
							|  |  |  |                             provider_credentials[variable] = encrypter.decrypt_token_with_decoding( | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                                 provider_credentials.get(variable), self.decoding_rsa_key, self.decoding_cipher_rsa | 
					
						
							| 
									
										
										
										
											2024-01-04 20:48:54 +08:00
										 |  |  |                             ) | 
					
						
							|  |  |  |                         except ValueError: | 
					
						
							|  |  |  |                             pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 # cache provider credentials | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 provider_credentials_cache.set(credentials=provider_credentials) | 
					
						
							| 
									
										
										
										
											2024-01-04 20:48:54 +08:00
										 |  |  |             else: | 
					
						
							|  |  |  |                 provider_credentials = cached_provider_credentials | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |             custom_provider_configuration = CustomProviderConfiguration(credentials=provider_credentials) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         # Get provider model credential secret variables | 
					
						
							|  |  |  |         model_credential_secret_variables = self._extract_secret_variables( | 
					
						
							|  |  |  |             provider_entity.model_credential_schema.credential_form_schemas | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |             if provider_entity.model_credential_schema | 
					
						
							|  |  |  |             else [] | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Get custom provider model credentials | 
					
						
							|  |  |  |         custom_model_configurations = [] | 
					
						
							|  |  |  |         for provider_model_record in provider_model_records: | 
					
						
							|  |  |  |             if not provider_model_record.encrypted_config: | 
					
						
							|  |  |  |                 continue | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-04 20:48:54 +08:00
										 |  |  |             provider_model_credentials_cache = ProviderCredentialsCache( | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 tenant_id=tenant_id, identity_id=provider_model_record.id, cache_type=ProviderCredentialsCacheType.MODEL | 
					
						
							| 
									
										
										
										
											2024-01-04 20:48:54 +08:00
										 |  |  |             ) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-04 20:48:54 +08:00
										 |  |  |             # Get cached provider model credentials | 
					
						
							|  |  |  |             cached_provider_model_credentials = provider_model_credentials_cache.get() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if not cached_provider_model_credentials: | 
					
						
							|  |  |  |                 try: | 
					
						
							|  |  |  |                     provider_model_credentials = json.loads(provider_model_record.encrypted_config) | 
					
						
							|  |  |  |                 except JSONDecodeError: | 
					
						
							|  |  |  |                     continue | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 # Get decoding rsa key and cipher for decrypting credentials | 
					
						
							| 
									
										
										
										
											2024-01-05 12:13:45 +08:00
										 |  |  |                 if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None: | 
					
						
							|  |  |  |                     self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) | 
					
						
							| 
									
										
										
										
											2024-01-04 20:48:54 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |                 for variable in model_credential_secret_variables: | 
					
						
							|  |  |  |                     if variable in provider_model_credentials: | 
					
						
							|  |  |  |                         try: | 
					
						
							|  |  |  |                             provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding( | 
					
						
							|  |  |  |                                 provider_model_credentials.get(variable), | 
					
						
							| 
									
										
										
										
											2024-01-05 12:13:45 +08:00
										 |  |  |                                 self.decoding_rsa_key, | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                                 self.decoding_cipher_rsa, | 
					
						
							| 
									
										
										
										
											2024-01-04 20:48:54 +08:00
										 |  |  |                             ) | 
					
						
							|  |  |  |                         except ValueError: | 
					
						
							|  |  |  |                             pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 # cache provider model credentials | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                 provider_model_credentials_cache.set(credentials=provider_model_credentials) | 
					
						
							| 
									
										
										
										
											2024-01-04 20:48:54 +08:00
										 |  |  |             else: | 
					
						
							|  |  |  |                 provider_model_credentials = cached_provider_model_credentials | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             custom_model_configurations.append( | 
					
						
							|  |  |  |                 CustomModelConfiguration( | 
					
						
							|  |  |  |                     model=provider_model_record.model_name, | 
					
						
							|  |  |  |                     model_type=ModelType.value_of(provider_model_record.model_type), | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                     credentials=provider_model_credentials, | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |                 ) | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         return CustomConfiguration(provider=custom_provider_configuration, models=custom_model_configurations) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |     def _to_system_configuration( | 
					
						
							|  |  |  |         self, tenant_id: str, provider_entity: ProviderEntity, provider_records: list[Provider] | 
					
						
							|  |  |  |     ) -> SystemConfiguration: | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Convert to system configuration. | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-04 20:48:54 +08:00
										 |  |  |         :param tenant_id: workspace id | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         :param provider_entity: provider entity | 
					
						
							|  |  |  |         :param provider_records: provider records | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         # Get hosting configuration | 
					
						
							|  |  |  |         hosting_configuration = ext_hosting_provider.hosting_configuration | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         if ( | 
					
						
							|  |  |  |             provider_entity.provider not in hosting_configuration.provider_map | 
					
						
							|  |  |  |             or not hosting_configuration.provider_map.get(provider_entity.provider).enabled | 
					
						
							|  |  |  |         ): | 
					
						
							|  |  |  |             return SystemConfiguration(enabled=False) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         provider_hosting_configuration = hosting_configuration.provider_map.get(provider_entity.provider) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Convert provider_records to dict | 
					
						
							| 
									
										
										
										
											2024-06-27 11:21:31 +08:00
										 |  |  |         quota_type_to_provider_records_dict = {} | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         for provider_record in provider_records: | 
					
						
							|  |  |  |             if provider_record.provider_type != ProviderType.SYSTEM.value: | 
					
						
							|  |  |  |                 continue | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |             quota_type_to_provider_records_dict[ProviderQuotaType.value_of(provider_record.quota_type)] = ( | 
					
						
							|  |  |  |                 provider_record | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         quota_configurations = [] | 
					
						
							|  |  |  |         for provider_quota in provider_hosting_configuration.quotas: | 
					
						
							|  |  |  |             if provider_quota.quota_type not in quota_type_to_provider_records_dict: | 
					
						
							| 
									
										
										
										
											2024-01-17 15:02:27 +08:00
										 |  |  |                 if provider_quota.quota_type == ProviderQuotaType.FREE: | 
					
						
							|  |  |  |                     quota_configuration = QuotaConfiguration( | 
					
						
							|  |  |  |                         quota_type=provider_quota.quota_type, | 
					
						
							|  |  |  |                         quota_unit=provider_hosting_configuration.quota_unit, | 
					
						
							|  |  |  |                         quota_used=0, | 
					
						
							|  |  |  |                         quota_limit=0, | 
					
						
							|  |  |  |                         is_valid=False, | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                         restrict_models=provider_quota.restrict_models, | 
					
						
							| 
									
										
										
										
											2024-01-17 15:02:27 +08:00
										 |  |  |                     ) | 
					
						
							|  |  |  |                 else: | 
					
						
							|  |  |  |                     continue | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 provider_record = quota_type_to_provider_records_dict[provider_quota.quota_type] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 quota_configuration = QuotaConfiguration( | 
					
						
							|  |  |  |                     quota_type=provider_quota.quota_type, | 
					
						
							|  |  |  |                     quota_unit=provider_hosting_configuration.quota_unit, | 
					
						
							|  |  |  |                     quota_used=provider_record.quota_used, | 
					
						
							|  |  |  |                     quota_limit=provider_record.quota_limit, | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                     is_valid=provider_record.quota_limit > provider_record.quota_used | 
					
						
							|  |  |  |                     or provider_record.quota_limit == -1, | 
					
						
							|  |  |  |                     restrict_models=provider_quota.restrict_models, | 
					
						
							| 
									
										
										
										
											2024-01-17 15:02:27 +08:00
										 |  |  |                 ) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             quota_configurations.append(quota_configuration) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if len(quota_configurations) == 0: | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |             return SystemConfiguration(enabled=False) | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         current_quota_type = self._choice_current_using_quota_type(quota_configurations) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         current_using_credentials = provider_hosting_configuration.credentials | 
					
						
							|  |  |  |         if current_quota_type == ProviderQuotaType.FREE: | 
					
						
							|  |  |  |             provider_record = quota_type_to_provider_records_dict.get(current_quota_type) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if provider_record: | 
					
						
							| 
									
										
										
										
											2024-01-04 20:48:54 +08:00
										 |  |  |                 provider_credentials_cache = ProviderCredentialsCache( | 
					
						
							|  |  |  |                     tenant_id=tenant_id, | 
					
						
							|  |  |  |                     identity_id=provider_record.id, | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                     cache_type=ProviderCredentialsCacheType.PROVIDER, | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |                 ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-04 20:48:54 +08:00
										 |  |  |                 # Get cached provider credentials | 
					
						
							|  |  |  |                 cached_provider_credentials = provider_credentials_cache.get() | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-04 20:48:54 +08:00
										 |  |  |                 if not cached_provider_credentials: | 
					
						
							|  |  |  |                     try: | 
					
						
							|  |  |  |                         provider_credentials = json.loads(provider_record.encrypted_config) | 
					
						
							|  |  |  |                     except JSONDecodeError: | 
					
						
							|  |  |  |                         provider_credentials = {} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     # Get provider credential secret variables | 
					
						
							|  |  |  |                     provider_credential_secret_variables = self._extract_secret_variables( | 
					
						
							|  |  |  |                         provider_entity.provider_credential_schema.credential_form_schemas | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                         if provider_entity.provider_credential_schema | 
					
						
							|  |  |  |                         else [] | 
					
						
							| 
									
										
										
										
											2024-01-04 20:48:54 +08:00
										 |  |  |                     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     # Get decoding rsa key and cipher for decrypting credentials | 
					
						
							| 
									
										
										
										
											2024-01-05 12:13:45 +08:00
										 |  |  |                     if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None: | 
					
						
							|  |  |  |                         self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(tenant_id) | 
					
						
							| 
									
										
										
										
											2024-01-04 20:48:54 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |                     for variable in provider_credential_secret_variables: | 
					
						
							|  |  |  |                         if variable in provider_credentials: | 
					
						
							|  |  |  |                             try: | 
					
						
							|  |  |  |                                 provider_credentials[variable] = encrypter.decrypt_token_with_decoding( | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                                     provider_credentials.get(variable), self.decoding_rsa_key, self.decoding_cipher_rsa | 
					
						
							| 
									
										
										
										
											2024-01-04 20:48:54 +08:00
										 |  |  |                                 ) | 
					
						
							|  |  |  |                             except ValueError: | 
					
						
							|  |  |  |                                 pass | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     current_using_credentials = provider_credentials | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     # cache provider credentials | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |                     provider_credentials_cache.set(credentials=current_using_credentials) | 
					
						
							| 
									
										
										
										
											2024-01-04 20:48:54 +08:00
										 |  |  |                 else: | 
					
						
							|  |  |  |                     current_using_credentials = cached_provider_credentials | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |             else: | 
					
						
							|  |  |  |                 current_using_credentials = {} | 
					
						
							| 
									
										
										
										
											2024-01-17 15:02:27 +08:00
										 |  |  |                 quota_configurations = [] | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         return SystemConfiguration( | 
					
						
							|  |  |  |             enabled=True, | 
					
						
							|  |  |  |             current_quota_type=current_quota_type, | 
					
						
							|  |  |  |             quota_configurations=quota_configurations, | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |             credentials=current_using_credentials, | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-26 20:45:03 +09:00
										 |  |  |     @staticmethod | 
					
						
							|  |  |  |     def _choice_current_using_quota_type(quota_configurations: list[QuotaConfiguration]) -> ProviderQuotaType: | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Choice current using quota type. | 
					
						
							|  |  |  |         paid quotas > provider free quotas > hosting trial quotas | 
					
						
							|  |  |  |         If there is still quota for the corresponding quota type according to the sorting, | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :param quota_configurations: | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         # convert to dict | 
					
						
							|  |  |  |         quota_type_to_quota_configuration_dict = { | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |             quota_configuration.quota_type: quota_configuration for quota_configuration in quota_configurations | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         last_quota_configuration = None | 
					
						
							|  |  |  |         for quota_type in [ProviderQuotaType.PAID, ProviderQuotaType.FREE, ProviderQuotaType.TRIAL]: | 
					
						
							|  |  |  |             if quota_type in quota_type_to_quota_configuration_dict: | 
					
						
							|  |  |  |                 last_quota_configuration = quota_type_to_quota_configuration_dict[quota_type] | 
					
						
							|  |  |  |                 if last_quota_configuration.is_valid: | 
					
						
							|  |  |  |                     return quota_type | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if last_quota_configuration: | 
					
						
							|  |  |  |             return last_quota_configuration.quota_type | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
										 |  |  |         raise ValueError("No quota type available") | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-08-26 20:45:03 +09:00
										 |  |  |     @staticmethod | 
					
						
							|  |  |  |     def _extract_secret_variables(credential_form_schemas: list[CredentialFormSchema]) -> list[str]: | 
					
						
							| 
									
										
										
										
											2024-01-02 23:42:00 +08:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         Extract secret input form variables. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         :param credential_form_schemas: | 
					
						
							|  |  |  |         :return: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         secret_input_form_variables = [] | 
					
						
							|  |  |  |         for credential_form_schema in credential_form_schemas: | 
					
						
							|  |  |  |             if credential_form_schema.type == FormType.SECRET_INPUT: | 
					
						
							|  |  |  |                 secret_input_form_variables.append(credential_form_schema.variable) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return secret_input_form_variables | 
					
						
							| 
									
										
										
										
											2024-06-05 00:13:04 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-09-10 17:00:20 +08:00
    def _to_model_settings(
        self,
        provider_entity: ProviderEntity,
        provider_model_settings: Optional[list[ProviderModelSetting]] = None,
        load_balancing_model_configs: Optional[list[LoadBalancingModelConfig]] = None,
    ) -> list[ModelSettings]:
        """
        Convert provider model settings into ModelSettings entries.

        Each model setting gets the load-balancing configurations that match its
        model name and model type. These are attached only when load balancing is
        enabled for that setting and more than one configuration was collected.

        :param provider_entity: provider entity; its model credential schema determines
            which credential fields are secret and must be decrypted
        :param provider_model_settings: provider model settings, including the enabled
            and load balancing enabled flags
        :param load_balancing_model_configs: load balancing model configs
        :return: one ModelSettings per provider model setting; an empty list when
            provider_model_settings is None or empty
        """
        # Get provider model credential secret variables
        model_credential_secret_variables = self._extract_secret_variables(
            provider_entity.model_credential_schema.credential_form_schemas
            if provider_entity.model_credential_schema
            else []
        )

        model_settings = []
        if not provider_model_settings:
            return model_settings

        for provider_model_setting in provider_model_settings:
            load_balancing_configs = []
            if provider_model_setting.load_balancing_enabled and load_balancing_model_configs:
                # Linear scan: every config is compared against this setting's model.
                for load_balancing_model_config in load_balancing_model_configs:
                    if (
                        load_balancing_model_config.model_name == provider_model_setting.model_name
                        and load_balancing_model_config.model_type == provider_model_setting.model_type
                    ):
                        # Disabled load-balancing entries are skipped entirely.
                        if not load_balancing_model_config.enabled:
                            continue

                        # Entries without stored credentials are dropped, except the
                        # "__inherit__" entry, which is kept with empty credentials
                        # (presumably it reuses the model's own credentials — TODO confirm).
                        if not load_balancing_model_config.encrypted_config:
                            if load_balancing_model_config.name == "__inherit__":
                                load_balancing_configs.append(
                                    ModelLoadBalancingConfiguration(
                                        id=load_balancing_model_config.id,
                                        name=load_balancing_model_config.name,
                                        credentials={},
                                    )
                                )
                            continue

                        # Decrypted credentials are cached per load-balancing config id.
                        provider_model_credentials_cache = ProviderCredentialsCache(
                            tenant_id=load_balancing_model_config.tenant_id,
                            identity_id=load_balancing_model_config.id,
                            cache_type=ProviderCredentialsCacheType.LOAD_BALANCING_MODEL,
                        )

                        # Get cached provider model credentials
                        cached_provider_model_credentials = provider_model_credentials_cache.get()

                        if not cached_provider_model_credentials:
                            try:
                                provider_model_credentials = json.loads(load_balancing_model_config.encrypted_config)
                            except JSONDecodeError:
                                # An unparseable stored config silently drops this entry.
                                continue

                            # Get decoding rsa key and cipher for decrypting credentials.
                            # They are loaded once and kept on self; NOTE(review): later
                            # entries reuse them even if their tenant_id differs — confirm
                            # that all configs passed here share one tenant.
                            if self.decoding_rsa_key is None or self.decoding_cipher_rsa is None:
                                self.decoding_rsa_key, self.decoding_cipher_rsa = encrypter.get_decrypt_decoding(
                                    load_balancing_model_config.tenant_id
                                )

                            for variable in model_credential_secret_variables:
                                if variable in provider_model_credentials:
                                    try:
                                        provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
                                            provider_model_credentials.get(variable),
                                            self.decoding_rsa_key,
                                            self.decoding_cipher_rsa,
                                        )
                                    except ValueError:
                                        # Best effort: keep the stored value if decryption fails.
                                        pass

                            # cache provider model credentials
                            provider_model_credentials_cache.set(credentials=provider_model_credentials)
                        else:
                            provider_model_credentials = cached_provider_model_credentials

                        load_balancing_configs.append(
                            ModelLoadBalancingConfiguration(
                                id=load_balancing_model_config.id,
                                name=load_balancing_model_config.name,
                                credentials=provider_model_credentials,
                            )
                        )

            model_settings.append(
                ModelSettings(
                    model=provider_model_setting.model_name,
                    model_type=ModelType.value_of(provider_model_setting.model_type),
                    enabled=provider_model_setting.enabled,
                    # A single collected config is discarded: load balancing needs at least two.
                    load_balancing_configs=load_balancing_configs if len(load_balancing_configs) > 1 else [],
                )
            )

        return model_settings