import json import logging import re from pathlib import Path from typing import Any, Optional from sqlalchemy.orm import Session from configs import dify_config from core.helper.position_helper import is_filtered from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache from core.plugin.entities.plugin import ToolProviderID from core.plugin.impl.exc import PluginDaemonClientSideError from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort from core.tools.entities.api_entities import ( ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity, ToolProviderCredentialInfoApiEntity, ) from core.tools.entities.tool_entities import CredentialType from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager from core.tools.utils.encryption import create_provider_encrypter from extensions.ext_database import db from extensions.ext_redis import redis_client from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient from services.tools.tools_transform_service import ToolTransformService logger = logging.getLogger(__name__) class BuiltinToolManageService: __MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100 @staticmethod def get_builtin_tool_provider_oauth_client_schema(tenant_id: str, provider_name: str): """ get builtin tool provider oauth client schema """ provider = ToolManager.get_builtin_provider(provider_name, tenant_id) return { "schema": provider.get_oauth_client_schema(), "is_oauth_custom_client_enabled": BuiltinToolManageService.is_oauth_custom_client_enabled( tenant_id, provider_name ), } @staticmethod def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]: """ list builtin tool provider tools :param tenant_id: the id of the tenant :param provider: the name of the provider :return: the list of tools """ provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) tools = provider_controller.get_tools() result: list[ToolApiEntity] = [] for tool in tools or []: result.append( ToolTransformService.convert_tool_entity_to_api_entity( tool=tool, tenant_id=tenant_id, labels=ToolLabelManager.get_tool_labels(provider_controller), ) ) return result @staticmethod def get_builtin_tool_provider_info(tenant_id: str, provider: str): """ get builtin tool provider info """ provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) # check if user has added the provider builtin_provider = BuiltinToolManageService.get_builtin_provider(provider, tenant_id) if builtin_provider is None: raise ValueError(f"you have not added provider {provider}") entity = ToolTransformService.builtin_provider_to_user_provider( provider_controller=provider_controller, db_provider=builtin_provider, decrypt_credentials=True, ) entity.original_credentials = {} return entity @staticmethod def list_builtin_provider_credentials_schema(provider_name: str, credential_type: CredentialType, tenant_id: str): """ list builtin provider credentials schema :param credential_type: credential type :param provider_name: the name of the provider :param tenant_id: the id of the tenant :return: the list of tool providers """ provider = ToolManager.get_builtin_provider(provider_name, tenant_id) return provider.get_credentials_schema_by_type(credential_type) @staticmethod def update_builtin_tool_provider( user_id: str, tenant_id: str, provider: str, credentials: dict, credential_id: str, name: str | None = None ): """ update builtin tool provider """ # get if the provider exists db_provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id) if db_provider is None: raise ValueError(f"you have not added provider {provider}") try: if CredentialType.of(db_provider.credential_type).is_editable(): provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) if not provider_controller.need_credentials: raise ValueError(f"provider {provider} does not need credentials") encrypter, cache = BuiltinToolManageService.create_tool_encrypter( tenant_id, db_provider, provider, provider_controller ) # Decrypt and restore original credentials for masked values original_credentials = encrypter.decrypt(db_provider.credentials) masked_credentials = encrypter.mask_tool_credentials(original_credentials) # check if the credential has changed, save the original credential for key, value in credentials.items(): if key in masked_credentials and value == masked_credentials[key]: credentials[key] = original_credentials[key] if CredentialType.of(db_provider.credential_type).is_validate_allowed(): provider_controller.validate_credentials(user_id, credentials) # encrypt credentials db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(credentials)) cache.delete() # update name if provided if name is not None and db_provider.name != name: db_provider.name = name db.session.commit() except ( PluginDaemonClientSideError, ToolProviderNotFoundError, ToolNotFoundError, ToolProviderCredentialValidationError, ) as e: db.session.rollback() raise ValueError(str(e)) return {"result": "success"} @staticmethod def add_builtin_tool_provider( user_id: str, api_type: CredentialType, tenant_id: str, provider: str, credentials: dict, name: str | None = None, ): """ add builtin tool provider """ lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}" try: with redis_client.lock(lock, timeout=20): provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) if not provider_controller.need_credentials: raise ValueError(f"provider {provider} does not need credentials") provider_count = ( db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count() ) # check if the provider count is reached the limit if provider_count >= BuiltinToolManageService.__MAX_BUILTIN_TOOL_PROVIDER_COUNT__: raise ValueError(f"you have reached the maximum number of providers for {provider}") # validate credentials if allowed if CredentialType.of(api_type).is_validate_allowed(): provider_controller.validate_credentials(user_id, credentials) # generate name if not provided if name is None: name = BuiltinToolManageService.generate_builtin_tool_provider_name( tenant_id=tenant_id, provider=provider, credential_type=api_type ) # create encrypter encrypter, _ = create_provider_encrypter( tenant_id=tenant_id, config=[ x.to_basic_provider_config() for x in provider_controller.get_credentials_schema_by_type(api_type) ], cache=NoOpProviderCredentialCache(), ) db_provider = BuiltinToolProvider( tenant_id=tenant_id, user_id=user_id, provider=provider, encrypted_credentials=json.dumps(encrypter.encrypt(credentials)), credential_type=api_type.value, name=name, ) db.session.add(db_provider) db.session.commit() except ( PluginDaemonClientSideError, ToolProviderNotFoundError, ToolNotFoundError, ToolProviderCredentialValidationError, ) as e: db.session.rollback() raise ValueError(str(e)) return {"result": "success"} @staticmethod def create_tool_encrypter( tenant_id: str, db_provider: BuiltinToolProvider, provider: str, provider_controller: BuiltinToolProviderController, ): encrypter, cache = create_provider_encrypter( tenant_id=tenant_id, config=[ x.to_basic_provider_config() for x in provider_controller.get_credentials_schema_by_type(db_provider.credential_type) ], cache=ToolProviderCredentialsCache(tenant_id=tenant_id, provider=provider, credential_id=db_provider.id), ) return encrypter, cache @staticmethod def generate_builtin_tool_provider_name(tenant_id: str, provider: str, credential_type: CredentialType) -> str: try: db_providers = ( db.session.query(BuiltinToolProvider) .filter_by( tenant_id=tenant_id, provider=provider, credential_type=credential_type.value, ) .order_by(BuiltinToolProvider.created_at.desc()) .all() ) # Get the default name pattern default_pattern = f"{credential_type.get_name()}" # Find all names that match the default pattern: "{default_pattern} {number}" pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$" numbers = [] for db_provider in db_providers: if db_provider.name: match = re.match(pattern, db_provider.name.strip()) if match: numbers.append(int(match.group(1))) # If no default pattern names found, start with 1 if not numbers: return f"{default_pattern} 1" # Find the next number max_number = max(numbers) return f"{default_pattern} {max_number + 1}" except Exception as e: logger.warning(f"Error generating next provider name for {provider}: {str(e)}") # fallback return f"{credential_type.get_name()} 1" @staticmethod def get_builtin_tool_provider_credentials( tenant_id: str, provider_name: str ) -> list[ToolProviderCredentialApiEntity]: """ get builtin tool provider credentials """ with db.session.no_autoflush: providers = ( db.session.query(BuiltinToolProvider) .filter_by(tenant_id=tenant_id, provider=provider_name) .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) .all() ) if len(providers) == 0: return [] default_provider = providers[0] default_provider.is_default = True provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id) encrypter, cache = BuiltinToolManageService.create_tool_encrypter( tenant_id, default_provider, default_provider.provider, provider_controller ) credentials: list[ToolProviderCredentialApiEntity] = [] for provider in providers: decrypt_credential = encrypter.mask_tool_credentials(encrypter.decrypt(provider.credentials)) credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity( provider=provider, credentials=decrypt_credential, ) credentials.append(credential_entity) return credentials @staticmethod def get_builtin_tool_provider_credential_info(tenant_id: str, provider: str) -> ToolProviderCredentialInfoApiEntity: """ get builtin tool provider credential info """ provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) supported_credential_types = provider_controller.get_supported_credential_types() credentials = BuiltinToolManageService.get_builtin_tool_provider_credentials(tenant_id, provider) credential_info = ToolProviderCredentialInfoApiEntity( supported_credential_types=supported_credential_types, is_oauth_custom_client_enabled=BuiltinToolManageService.is_oauth_custom_client_enabled(tenant_id, provider), credentials=credentials, ) return credential_info @staticmethod def delete_builtin_tool_provider(tenant_id: str, provider: str, credential_id: str): """ delete tool provider """ tool_provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id) if tool_provider is None: raise ValueError(f"you have not added provider {provider}") db.session.delete(tool_provider) db.session.commit() # delete cache provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) _, cache = BuiltinToolManageService.create_tool_encrypter( tenant_id, tool_provider, provider, provider_controller ) cache.delete() return {"result": "success"} @staticmethod def set_default_provider(tenant_id: str, user_id: str, provider: str, id: str): """ set default provider """ with Session(db.engine) as session: # get provider target_provider = session.query(BuiltinToolProvider).filter_by(id=id).first() if target_provider is None: raise ValueError("provider not found") # clear default provider session.query(BuiltinToolProvider).filter_by( tenant_id=tenant_id, user_id=user_id, provider=provider, is_default=True ).update({"is_default": False}) # set new default provider target_provider.is_default = True session.commit() return {"result": "success"} @staticmethod def is_oauth_custom_client_enabled(tenant_id: str, provider: str) -> bool: """ check if oauth custom client is enabled """ tool_provider = ToolProviderID(provider) with Session(db.engine).no_autoflush as session: user_client: ToolOAuthTenantClient | None = ( session.query(ToolOAuthTenantClient) .filter_by( tenant_id=tenant_id, provider=tool_provider.provider_name, plugin_id=tool_provider.plugin_id, enabled=True, ) .first() ) return user_client is not None and user_client.enabled @staticmethod def get_oauth_client(tenant_id: str, provider: str) -> dict[str, Any] | None: """ get builtin tool provider """ tool_provider = ToolProviderID(provider) provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) encrypter, _ = create_provider_encrypter( tenant_id=tenant_id, config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], cache=NoOpProviderCredentialCache(), ) with Session(db.engine).no_autoflush as session: user_client: ToolOAuthTenantClient | None = ( session.query(ToolOAuthTenantClient) .filter_by( tenant_id=tenant_id, provider=tool_provider.provider_name, plugin_id=tool_provider.plugin_id, enabled=True, ) .first() ) oauth_params: dict[str, Any] | None = None if user_client: oauth_params = encrypter.decrypt(user_client.oauth_params) return oauth_params system_client: ToolOAuthSystemClient | None = ( session.query(ToolOAuthSystemClient) .filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name) .first() ) if system_client: oauth_params = encrypter.decrypt(system_client.oauth_params) return oauth_params @staticmethod def get_builtin_tool_provider_icon(provider: str): """ get tool provider icon and it's mimetype """ icon_path, mime_type = ToolManager.get_hardcoded_provider_icon(provider) icon_bytes = Path(icon_path).read_bytes() return icon_bytes, mime_type @staticmethod def list_builtin_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]: """ list builtin tools """ # get all builtin providers provider_controllers = ToolManager.list_builtin_providers(tenant_id) with db.session.no_autoflush: # get all user added providers db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id) # rewrite db_providers for db_provider in db_providers: db_provider.provider = str(ToolProviderID(db_provider.provider)) # find provider def find_provider(provider): return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None) result: list[ToolProviderApiEntity] = [] for provider_controller in provider_controllers: try: # handle include, exclude if is_filtered( include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore data=provider_controller, name_func=lambda x: x.identity.name, ): continue # convert provider controller to user provider user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider( provider_controller=provider_controller, db_provider=find_provider(provider_controller.entity.identity.name), decrypt_credentials=True, ) # add icon ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider) tools = provider_controller.get_tools() for tool in tools or []: user_builtin_provider.tools.append( ToolTransformService.convert_tool_entity_to_api_entity( tenant_id=tenant_id, tool=tool, labels=ToolLabelManager.get_tool_labels(provider_controller), ) ) result.append(user_builtin_provider) except Exception as e: raise e return BuiltinToolProviderSort.sort(result) @staticmethod def get_builtin_provider_by_id(tenant_id: str, credential_id: str) -> Optional[BuiltinToolProvider]: provider: Optional[BuiltinToolProvider] = ( db.session.query(BuiltinToolProvider) .filter( BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.id == credential_id, ) .first() ) return provider @staticmethod def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]: """ This method is used to fetch the builtin provider from the database 1.if the default provider exists, return the default provider 2.if the default provider does not exist, return the oldest provider """ with Session(db.engine) as session: try: full_provider_name = provider_name provider_id_entity = ToolProviderID(provider_name) provider_name = provider_id_entity.provider_name if provider_id_entity.organization != "langgenius": provider = ( session.query(BuiltinToolProvider) .filter( BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == full_provider_name, ) .order_by( BuiltinToolProvider.is_default.desc(), # default=True first BuiltinToolProvider.created_at.asc(), # oldest first ) .first() ) else: provider = ( session.query(BuiltinToolProvider) .filter( BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_name) | (BuiltinToolProvider.provider == full_provider_name), ) .order_by( BuiltinToolProvider.is_default.desc(), # default=True first BuiltinToolProvider.created_at.asc(), # oldest first ) .first() ) if provider is None: return None provider.provider = ToolProviderID(provider.provider).to_string() return provider except Exception: # it's an old provider without organization return ( session.query(BuiltinToolProvider) .filter(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name) .order_by( BuiltinToolProvider.is_default.desc(), # default=True first BuiltinToolProvider.created_at.asc(), # oldest first ) .first() ) @staticmethod def save_custom_oauth_client_params( tenant_id: str, provider: str, client_params: Optional[dict] = None, enable_oauth_custom_client: Optional[bool] = None, ): """ setup oauth custom client """ if client_params is None and enable_oauth_custom_client is None: return {"result": "success"} tool_provider = ToolProviderID(provider) provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) if not provider_controller: raise ToolProviderNotFoundError(f"Provider {provider} not found") if not isinstance(provider_controller, (BuiltinToolProviderController, PluginToolProviderController)): raise ValueError(f"Provider {provider} is not a builtin or plugin provider") with Session(db.engine) as session: custom_client_params = ( session.query(ToolOAuthTenantClient) .filter_by( tenant_id=tenant_id, plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name, ) .first() ) # if the record does not exist, create a basic record if custom_client_params is None: custom_client_params = ToolOAuthTenantClient( tenant_id=tenant_id, plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name, ) session.add(custom_client_params) if client_params is not None: encrypter, _ = create_provider_encrypter( tenant_id=tenant_id, config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], cache=NoOpProviderCredentialCache(), ) custom_client_params.encrypted_oauth_params = json.dumps(encrypter.encrypt(client_params)) if enable_oauth_custom_client is not None: custom_client_params.enabled = enable_oauth_custom_client session.commit() return {"result": "success"} @staticmethod def get_custom_oauth_client_params(tenant_id: str, provider: str): """ get custom oauth client params """ with Session(db.engine) as session: tool_provider = ToolProviderID(provider) custom_oauth_client_params: ToolOAuthTenantClient | None = ( session.query(ToolOAuthTenantClient) .filter_by( tenant_id=tenant_id, plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name, ) .first() ) if custom_oauth_client_params is None: return {} provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) if not provider_controller: raise ToolProviderNotFoundError(f"Provider {provider} not found") if not isinstance(provider_controller, BuiltinToolProviderController): raise ValueError(f"Provider {provider} is not a builtin or plugin provider") encrypter, _ = create_provider_encrypter( tenant_id=tenant_id, config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], cache=NoOpProviderCredentialCache(), ) return encrypter.mask_tool_credentials(encrypter.decrypt(custom_oauth_client_params.oauth_params))