import logging
from typing import Any

from flask_login import current_user
from sqlalchemy.orm import Session

from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.helper import encrypter
from core.helper.name_generator import generate_incremental_name
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.model_runtime.entities.provider_entities import FormType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.plugin.entities.plugin import DatasourceProviderID
from core.plugin.impl.datasource import PluginDatasourceManager
from core.tools.entities.tool_entities import CredentialType
from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider

logger = logging.getLogger(__name__)


class DatasourceProviderService:
    """
    Datasource provider service.

    Manages per-tenant datasource credentials (``DatasourceProvider`` rows, whose secret
    form fields are stored encrypted via ``encrypter.encrypt_token``) and the OAuth client
    parameters (tenant-level ``DatasourceOauthTenantParamConfig`` and system-level
    ``DatasourceOauthParamConfig``).
    """

    def __init__(self) -> None:
        self.provider_manager = PluginDatasourceManager()

    # ------------------------------------------------------------------ #
    # OAuth client params
    # ------------------------------------------------------------------ #

    def setup_oauth_custom_client_params(
        self,
        tenant_id: str,
        datasource_provider_id: DatasourceProviderID,
        client_params: dict | None,
        enabled: bool | None,
    ):
        """
        Create or update the tenant's custom OAuth client params for a datasource provider.

        :param tenant_id: workspace id
        :param datasource_provider_id: plugin/provider identifier
        :param client_params: new client params, or ``None`` to leave them untouched. A value equal
            to ``HIDDEN_VALUE`` keeps the previously stored value (``UNKNOWN_VALUE`` if none exists).
        :param enabled: new enabled flag, or ``None`` to leave it untouched
        :raises ValueError: if the provider declares no OAuth schema
        """
        if client_params is None and enabled is None:
            return

        datasource_provider = self.provider_manager.fetch_datasource_provider(
            tenant_id=tenant_id, provider_id=str(datasource_provider_id)
        )
        oauth_schema = datasource_provider.declaration.oauth_schema
        if not oauth_schema:
            raise ValueError("Datasource provider oauth schema not found")

        with Session(db.engine) as session:
            tenant_oauth_client_params = (
                session.query(DatasourceOauthTenantParamConfig)
                .filter_by(
                    tenant_id=tenant_id,
                    provider=datasource_provider_id.provider_name,
                    plugin_id=datasource_provider_id.plugin_id,
                )
                .first()
            )
            if not tenant_oauth_client_params:
                tenant_oauth_client_params = DatasourceOauthTenantParamConfig(
                    tenant_id=tenant_id,
                    provider=datasource_provider_id.provider_name,
                    plugin_id=datasource_provider_id.plugin_id,
                    client_params={},
                    enabled=False,
                )
                session.add(tenant_oauth_client_params)

            if client_params is not None:
                client_encrypter, _ = create_provider_encrypter(
                    tenant_id=tenant_id,
                    config=[x.to_basic_provider_config() for x in oauth_schema.client_schema],
                    cache=NoOpProviderCredentialCache(),
                )
                # The row always exists here (either fetched or just created above).
                original_params = client_encrypter.decrypt(tenant_oauth_client_params.client_params)
                new_params: dict = {
                    key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
                    for key, value in client_params.items()
                }
                tenant_oauth_client_params.client_params = client_encrypter.encrypt(new_params)

            if enabled is not None:
                tenant_oauth_client_params.enabled = enabled

            session.commit()

    def is_system_oauth_params_exist(self, datasource_provider_id: DatasourceProviderID) -> bool:
        """Return True if a system-level OAuth client config exists for the provider."""
        # Use the Session as a context manager so its connection is released deterministically.
        with Session(db.engine) as session:
            return (
                session.query(DatasourceOauthParamConfig)
                .filter_by(provider=datasource_provider_id.provider_name, plugin_id=datasource_provider_id.plugin_id)
                .first()
                is not None
            )

    def is_tenant_oauth_params_enabled(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> bool:
        """Return True if the tenant has an enabled custom OAuth client config for the provider."""
        with Session(db.engine) as session:
            return (
                session.query(DatasourceOauthTenantParamConfig)
                .filter_by(
                    tenant_id=tenant_id,
                    provider=datasource_provider_id.provider_name,
                    plugin_id=datasource_provider_id.plugin_id,
                    enabled=True,
                )
                .count()
                > 0
            )

    def get_tenant_oauth_client(
        self, tenant_id: str, datasource_provider_id: DatasourceProviderID
    ) -> dict[str, Any] | None:
        """
        Return the tenant's decrypted custom OAuth client params, whether enabled or not.

        :return: decrypted params, or ``None`` if the tenant has no custom config row
        """
        with Session(db.engine) as session:
            tenant_oauth_client_params = (
                session.query(DatasourceOauthTenantParamConfig)
                .filter_by(
                    tenant_id=tenant_id,
                    provider=datasource_provider_id.provider_name,
                    plugin_id=datasource_provider_id.plugin_id,
                )
                .first()
            )
            if tenant_oauth_client_params:
                oauth_encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
                return oauth_encrypter.decrypt(tenant_oauth_client_params.client_params)
        return None

    def get_oauth_encrypter(
        self, tenant_id: str, datasource_provider_id: DatasourceProviderID
    ) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
        """
        Build an encrypter for the provider's OAuth client schema.

        :raises ValueError: if the provider declares no OAuth schema
        """
        datasource_provider = self.provider_manager.fetch_datasource_provider(
            tenant_id=tenant_id, provider_id=str(datasource_provider_id)
        )
        if not datasource_provider.declaration.oauth_schema:
            raise ValueError("Datasource provider oauth schema not found")

        client_schema = datasource_provider.declaration.oauth_schema.client_schema

        return create_provider_encrypter(
            tenant_id=tenant_id,
            config=[x.to_basic_provider_config() for x in client_schema],
            cache=NoOpProviderCredentialCache(),
        )

    def get_oauth_client(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> dict[str, Any] | None:
        """
        Resolve the OAuth client params to use for the provider.

        Prefers the tenant's enabled custom client params (decrypted); otherwise falls back to
        the system-level ``system_credentials``.

        :raises ValueError: if neither a tenant nor a system config exists
        """
        provider = datasource_provider_id.provider_name
        plugin_id = datasource_provider_id.plugin_id
        with Session(db.engine) as session:
            # get tenant oauth client params
            tenant_oauth_client_params = (
                session.query(DatasourceOauthTenantParamConfig)
                .filter_by(
                    tenant_id=tenant_id,
                    provider=provider,
                    plugin_id=plugin_id,
                    enabled=True,
                )
                .first()
            )
            if tenant_oauth_client_params:
                oauth_encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
                return oauth_encrypter.decrypt(tenant_oauth_client_params.client_params)

            # fallback to system oauth client params
            oauth_client_params = (
                session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
            )
            if oauth_client_params:
                return oauth_client_params.system_credentials

        raise ValueError(f"Please configure oauth client params(system/tenant) for {plugin_id}/{provider}")

    # ------------------------------------------------------------------ #
    # Credential creation
    # ------------------------------------------------------------------ #

    @staticmethod
    def generate_next_datasource_provider_name(
        session: Session, tenant_id: str, provider_id: DatasourceProviderID, credential_type: CredentialType
    ) -> str:
        """Generate the next free incremental credential name for this provider and auth type."""
        db_providers = (
            session.query(DatasourceProvider)
            .filter_by(
                tenant_id=tenant_id,
                provider=provider_id.provider_name,
                plugin_id=provider_id.plugin_id,
                auth_type=credential_type.value,
            )
            .all()
        )
        return generate_incremental_name(
            [provider.name for provider in db_providers],
            f"{credential_type.get_name()}",
        )

    @staticmethod
    def _encrypt_secret_values(tenant_id: str, credentials: dict, secret_variables: list[str]) -> dict:
        """Return a new dict with every key listed in *secret_variables* encrypted; other values are copied."""
        return {
            key: encrypter.encrypt_token(tenant_id, value) if key in secret_variables else value
            for key, value in credentials.items()
        }

    def add_datasource_oauth_provider(
        self,
        name: str | None,
        tenant_id: str,
        provider_id: DatasourceProviderID,
        avatar_url: str | None,
        credentials: dict,
    ) -> None:
        """
        Store a new OAuth2 credential for the provider.

        If *name* is empty a new incremental name is generated; if it clashes with an existing
        OAuth2 credential of the same provider, an incremental variant of it is used instead.
        Secret fields are encrypted before storage; the caller's *credentials* dict is not modified.
        """
        credential_type = CredentialType.OAUTH2
        with Session(db.engine) as session:
            lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{credential_type.value}"
            with redis_client.lock(lock, timeout=20):
                db_provider_name = name
                if not db_provider_name:
                    db_provider_name = self.generate_next_datasource_provider_name(
                        session=session,
                        tenant_id=tenant_id,
                        provider_id=provider_id,
                        credential_type=credential_type,
                    )
                elif (
                    session.query(DatasourceProvider)
                    .filter_by(
                        tenant_id=tenant_id,
                        name=db_provider_name,
                        provider=provider_id.provider_name,
                        plugin_id=provider_id.plugin_id,
                        auth_type=credential_type.value,
                    )
                    .count()
                    > 0
                ):
                    db_provider_name = generate_incremental_name(
                        [
                            provider.name
                            for provider in session.query(DatasourceProvider).filter_by(
                                tenant_id=tenant_id,
                                provider=provider_id.provider_name,
                                plugin_id=provider_id.plugin_id,
                            )
                        ],
                        db_provider_name,
                    )

                secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider_id=f"{provider_id}")
                datasource_provider = DatasourceProvider(
                    tenant_id=tenant_id,
                    name=db_provider_name,
                    provider=provider_id.provider_name,
                    plugin_id=provider_id.plugin_id,
                    auth_type=credential_type.value,
                    encrypted_credentials=self._encrypt_secret_values(tenant_id, credentials, secret_variables),
                    avatar_url=avatar_url or "default",
                )
                session.add(datasource_provider)
                session.commit()

    def add_datasource_api_key_provider(
        self,
        name: str | None,
        tenant_id: str,
        provider_id: DatasourceProviderID,
        credentials: dict,
    ) -> None:
        """
        Validate and store a new API-key credential for the provider.

        :param name: credential name; generated incrementally when empty
        :param tenant_id: workspace id
        :param provider_id: plugin/provider identifier
        :param credentials: plaintext credentials (not modified; secret fields are encrypted in the stored copy)
        :raises ValueError: if a credential with the same name already exists for the tenant
        :raises CredentialsValidateFailedError: if the plugin rejects the credentials
        """
        provider_name = provider_id.provider_name
        plugin_id = provider_id.plugin_id
        with Session(db.engine) as session:
            lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_api_key"
            with redis_client.lock(lock, timeout=20):
                # NOTE(review): rows are stored with auth_type="api_key" below, while name generation
                # filters on CredentialType.API_KEY.value — confirm these two values match.
                db_provider_name = name or self.generate_next_datasource_provider_name(
                    session=session,
                    tenant_id=tenant_id,
                    provider_id=provider_id,
                    credential_type=CredentialType.API_KEY,
                )

                # NOTE(review): unlike the OAuth path, this uniqueness check is not scoped to
                # provider/plugin — confirm that is intended.
                if session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, name=db_provider_name).count() > 0:
                    raise ValueError("Authorization name is already exists")

                credential_valid = self.provider_manager.validate_provider_credentials(
                    tenant_id=tenant_id,
                    user_id=current_user.id,
                    provider=provider_name,
                    plugin_id=plugin_id,
                    credentials=credentials,
                )
                if not credential_valid:
                    raise CredentialsValidateFailedError()

                secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider_id=f"{provider_id}")
                datasource_provider = DatasourceProvider(
                    tenant_id=tenant_id,
                    name=db_provider_name,
                    provider=provider_name,
                    plugin_id=plugin_id,
                    auth_type="api_key",
                    encrypted_credentials=self._encrypt_secret_values(tenant_id, credentials, secret_variables),
                )
                # Write through the same session used for the checks above, while the lock is held.
                session.add(datasource_provider)
                session.commit()

    def extract_secret_variables(self, tenant_id: str, provider_id: str) -> list[str]:
        """
        Return the names of the provider's credential fields whose form type is SECRET_INPUT.

        :param tenant_id: workspace id
        :param provider_id: "plugin_id/provider" identifier string
        :return: list of secret field names
        """
        datasource_provider = self.provider_manager.fetch_datasource_provider(
            tenant_id=tenant_id, provider_id=provider_id
        )
        return [
            form_schema.name
            for form_schema in datasource_provider.declaration.credentials_schema
            if form_schema.type.value == FormType.SECRET_INPUT.value
        ]

    # ------------------------------------------------------------------ #
    # Credential listing
    # ------------------------------------------------------------------ #

    @staticmethod
    def _list_datasource_providers(tenant_id: str, provider: str, plugin_id: str) -> list[DatasourceProvider]:
        """Fetch all stored credential rows of the tenant for one plugin/provider."""
        return (
            db.session.query(DatasourceProvider)
            .filter(
                DatasourceProvider.tenant_id == tenant_id,
                DatasourceProvider.provider == provider,
                DatasourceProvider.plugin_id == plugin_id,
            )
            .all()
        )

    def get_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
        """
        List the tenant's credentials for a provider with secret fields obfuscated.

        :param tenant_id: workspace id
        :param provider: provider name
        :param plugin_id: plugin id
        :return: list of dicts with keys credential/type/name/avatar_url/id
        """
        datasource_providers = self._list_datasource_providers(tenant_id, provider, plugin_id)
        if not datasource_providers:
            return []

        # Secret fields depend only on the provider, so resolve them once rather than per row.
        credential_secret_variables = self.extract_secret_variables(
            tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}"
        )

        copy_credentials_list = []
        for datasource_provider in datasource_providers:
            obfuscated_credentials = {
                key: encrypter.obfuscated_token(value) if key in credential_secret_variables else value
                for key, value in datasource_provider.encrypted_credentials.items()
            }
            copy_credentials_list.append(
                {
                    "credential": obfuscated_credentials,
                    "type": datasource_provider.auth_type,
                    "name": datasource_provider.name,
                    "avatar_url": datasource_provider.avatar_url,
                    "id": datasource_provider.id,
                }
            )

        return copy_credentials_list

    @staticmethod
    def _serialize_form_schema(form_schema) -> dict[str, Any]:
        """Convert one credential/client form-schema entry into the dict exposed by the API."""
        return {
            "type": form_schema.type.value,
            "name": form_schema.name,
            "required": form_schema.required,
            "default": form_schema.default,
            "options": [
                {
                    "value": option.value,
                    "label": option.label.model_dump(),
                }
                for option in form_schema.options or []
            ],
        }

    def _build_oauth_schema_payload(
        self, tenant_id: str, datasource_provider_id: DatasourceProviderID, oauth_schema
    ) -> dict[str, Any] | None:
        """Build the "oauth_schema" section for one provider, or None when it declares no OAuth schema."""
        if not oauth_schema:
            return None
        return {
            "client_schema": [self._serialize_form_schema(s) for s in oauth_schema.client_schema or []],
            "credentials_schema": [self._serialize_form_schema(s) for s in oauth_schema.credentials_schema or []],
            "oauth_custom_client_params": self.get_tenant_oauth_client(tenant_id, datasource_provider_id),
            "is_oauth_custom_client_enabled": self.is_tenant_oauth_params_enabled(tenant_id, datasource_provider_id),
            "is_system_oauth_params_exists": self.is_system_oauth_params_exist(datasource_provider_id),
        }

    def get_all_datasource_credentials(self, tenant_id: str) -> list[dict]:
        """
        Describe every installed datasource provider of the tenant together with its credentials.

        :param tenant_id: workspace id
        :return: one dict per installed provider (identity, obfuscated credentials, schemas)
        """
        datasources = self.provider_manager.fetch_installed_datasource_providers(tenant_id)
        datasource_credentials = []
        for datasource in datasources:
            datasource_provider_id = DatasourceProviderID(f"{datasource.plugin_id}/{datasource.provider}")
            credentials = self.get_datasource_credentials(
                tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
            )
            declaration = datasource.declaration
            identity = declaration.identity
            datasource_credentials.append(
                {
                    "provider": datasource.provider,
                    "plugin_id": datasource.plugin_id,
                    "plugin_unique_identifier": datasource.plugin_unique_identifier,
                    "icon": identity.icon,
                    "name": identity.name,
                    "label": identity.label.model_dump(),
                    "description": identity.description.model_dump(),
                    "author": identity.author,
                    "credentials_list": credentials,
                    "credential_schema": [
                        self._serialize_form_schema(credential) for credential in declaration.credentials_schema
                    ],
                    "oauth_schema": self._build_oauth_schema_payload(
                        tenant_id, datasource_provider_id, declaration.oauth_schema
                    ),
                }
            )
        return datasource_credentials

    def get_real_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
        """
        List the tenant's credentials for a provider with secret fields decrypted.

        :param tenant_id: workspace id
        :param provider: provider name
        :param plugin_id: plugin id
        :return: list of dicts with keys credentials/type
        """
        datasource_providers = self._list_datasource_providers(tenant_id, provider, plugin_id)
        if not datasource_providers:
            return []

        # Secret fields depend only on the provider, so resolve them once rather than per row.
        credential_secret_variables = self.extract_secret_variables(
            tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}"
        )

        copy_credentials_list = []
        for datasource_provider in datasource_providers:
            decrypted_credentials = {
                key: encrypter.decrypt_token(tenant_id, value) if key in credential_secret_variables else value
                for key, value in datasource_provider.encrypted_credentials.items()
            }
            copy_credentials_list.append(
                {
                    "credentials": decrypted_credentials,
                    "type": datasource_provider.auth_type,
                }
            )

        return copy_credentials_list

    # ------------------------------------------------------------------ #
    # Credential update / removal
    # ------------------------------------------------------------------ #

    def update_datasource_credentials(
        self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict
    ) -> None:
        """
        Validate and replace the stored credentials of one credential row.

        A secret field sent as ``HIDDEN_VALUE`` keeps its currently stored secret: the stored
        ciphertext is decrypted so that validation sees the real value, then re-encrypted once.

        :raises ValueError: if the credential row does not exist
        :raises CredentialsValidateFailedError: if the plugin rejects the credentials
        """
        datasource_provider = (
            db.session.query(DatasourceProvider)
            .filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
            .first()
        )
        if not datasource_provider:
            raise ValueError("Datasource provider not found")

        secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}")
        stored_credentials = datasource_provider.encrypted_credentials

        # Substitute HIDDEN_VALUE placeholders with the stored plaintext; work on a copy so the
        # caller's dict is not modified.
        merged_credentials: dict = {}
        for key, value in credentials.items():
            if key in secret_variables and value == HIDDEN_VALUE and key in stored_credentials:
                value = encrypter.decrypt_token(tenant_id, stored_credentials[key])
            merged_credentials[key] = value

        credential_valid = self.provider_manager.validate_provider_credentials(
            tenant_id=tenant_id,
            user_id=current_user.id,
            provider=provider,
            plugin_id=plugin_id,
            credentials=merged_credentials,
        )
        if not credential_valid:
            raise CredentialsValidateFailedError()

        # Encrypt each secret exactly once (the previous code re-encrypted stored ciphertext twice).
        datasource_provider.encrypted_credentials = self._encrypt_secret_values(
            tenant_id, merged_credentials, secret_variables
        )
        db.session.commit()

    def remove_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str) -> None:
        """
        Delete one stored credential row; does nothing if it does not exist.

        :param tenant_id: workspace id
        :param auth_id: credential row id
        :param provider: provider name
        :param plugin_id: plugin id
        """
        datasource_provider = (
            db.session.query(DatasourceProvider)
            .filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
            .first()
        )
        if datasource_provider:
            db.session.delete(datasource_provider)
            db.session.commit()