# dify/api/services/datasource_provider_service.py
# 2025-07-28 19:29:36 +08:00
#
# 721 lines
# 30 KiB
# Python

import logging
from typing import Any
from flask_login import current_user
from sqlalchemy.orm import Session
from configs import dify_config
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.helper import encrypter
from core.helper.name_generator import generate_incremental_name
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.model_runtime.entities.provider_entities import FormType
from core.plugin.entities.plugin import DatasourceProviderID
from core.plugin.impl.datasource import PluginDatasourceManager
from core.tools.entities.tool_entities import CredentialType
from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
class DatasourceProviderService:
    """
    Datasource Provider Service.

    Manages the credentials a workspace (tenant) stores for datasource plugins
    (``DatasourceProvider`` rows, either API-key or OAuth2), plus the OAuth client
    parameters configured at system level (``DatasourceOauthParamConfig``) or at
    tenant level (``DatasourceOauthTenantParamConfig``).

    Credential fields whose declared form type is ``SECRET_INPUT`` are encrypted
    with ``encrypter.encrypt_token`` before being persisted.
    """

    def __init__(self) -> None:
        # Used to fetch datasource provider declarations and to validate credentials.
        self.provider_manager = PluginDatasourceManager()

    @staticmethod
    def _encrypt_secret_fields(
        tenant_id: str, credentials: dict[str, Any], secret_variables: list[str]
    ) -> dict[str, Any]:
        """
        Return a new dict in which every field listed in ``secret_variables`` is encrypted.

        The input dict is left untouched; non-secret fields are copied as-is.
        """
        return {
            key: encrypter.encrypt_token(tenant_id, value) if key in secret_variables else value
            for key, value in credentials.items()
        }

    @staticmethod
    def _decrypt_secret_fields(
        tenant_id: str, credentials: dict[str, Any], secret_variables: list[str]
    ) -> dict[str, Any]:
        """
        Return a new dict in which every field listed in ``secret_variables`` is decrypted.

        The input dict is left untouched; non-secret fields are copied as-is.
        """
        return {
            key: encrypter.decrypt_token(tenant_id, value) if key in secret_variables else value
            for key, value in credentials.items()
        }

    def remove_oauth_custom_client_params(self, tenant_id: str, datasource_provider_id: DatasourceProviderID):
        """
        Remove the tenant-level OAuth client params configured for a datasource provider.

        :param tenant_id: workspace id
        :param datasource_provider_id: provider whose tenant OAuth client config is deleted
        """
        with Session(db.engine) as session:
            session.query(DatasourceOauthTenantParamConfig).filter_by(
                tenant_id=tenant_id,
                provider=datasource_provider_id.provider_name,
                plugin_id=datasource_provider_id.plugin_id,
            ).delete()
            session.commit()

    def get_default_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> dict[str, Any]:
        """
        Get the default credentials stored for a datasource provider.

        The row flagged ``is_default`` wins; if none is flagged, the oldest row
        (by ``created_at``) is used.

        :param tenant_id: workspace id
        :param provider: provider name
        :param plugin_id: plugin id
        :return: the stored credentials as-is (secret fields are NOT decrypted here),
            or ``{}`` when the tenant has no credential for this provider
        """
        with Session(db.engine) as session:
            datasource_provider = (
                session.query(DatasourceProvider)
                .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
                .order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
                .first()
            )
            if not datasource_provider:
                return {}
            return datasource_provider.encrypted_credentials

    def get_real_credential_by_id(
        self, tenant_id: str, credential_id: str, provider: str, plugin_id: str
    ) -> dict[str, Any]:
        """
        Get one credential by id with its secret fields decrypted.

        :param tenant_id: workspace id
        :param credential_id: id of the ``DatasourceProvider`` row
        :param provider: provider name, used to look up which fields are secret
        :param plugin_id: plugin id, used to look up which fields are secret
        :return: plaintext credentials, or ``{}`` if no row matches
        """
        with Session(db.engine) as session:
            # NOTE(review): the row is matched on tenant and id only; ``provider`` and
            # ``plugin_id`` are not checked against it, only used for the schema lookup.
            datasource_provider = (
                session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
            )
            if not datasource_provider:
                return {}
            credential_secret_variables = self.extract_secret_variables(
                tenant_id=tenant_id,
                provider_id=f"{plugin_id}/{provider}",
                credential_type=CredentialType.of(datasource_provider.auth_type),
            )
            # Decrypt the secret fields of a copy of the stored credentials.
            return self._decrypt_secret_fields(
                tenant_id, datasource_provider.encrypted_credentials, credential_secret_variables
            )

    def update_datasource_provider_name(
        self, tenant_id: str, datasource_provider_id: DatasourceProviderID, name: str, credential_id: str
    ):
        """
        Rename a stored datasource credential.

        :param tenant_id: workspace id
        :param datasource_provider_id: provider the credential belongs to
        :param name: new display name
        :param credential_id: id of the ``DatasourceProvider`` row
        :raises ValueError: if the credential does not exist or the name is already used
            by another credential of the same provider
        """
        with Session(db.engine) as session:
            target_provider = (
                session.query(DatasourceProvider)
                .filter_by(
                    tenant_id=tenant_id,
                    id=credential_id,
                    provider=datasource_provider_id.provider_name,
                    plugin_id=datasource_provider_id.plugin_id,
                )
                .first()
            )
            if target_provider is None:
                raise ValueError("provider not found")
            if target_provider.name == name:
                return
            # Names must be unique per (tenant, provider, plugin).
            if (
                session.query(DatasourceProvider)
                .filter_by(
                    tenant_id=tenant_id,
                    name=name,
                    provider=datasource_provider_id.provider_name,
                    plugin_id=datasource_provider_id.plugin_id,
                )
                .count()
                > 0
            ):
                raise ValueError("Authorization name is already exists")
            target_provider.name = name
            session.commit()

    def set_default_datasource_provider(
        self, tenant_id: str, datasource_provider_id: DatasourceProviderID, credential_id: str
    ):
        """
        Mark one stored credential as the default for its provider.

        Any other credential of the same provider that was flagged default is cleared.

        :raises ValueError: if the credential does not exist
        :return: ``{"result": "success"}``
        """
        with Session(db.engine) as session:
            target_provider = (
                session.query(DatasourceProvider)
                .filter_by(
                    tenant_id=tenant_id,
                    id=credential_id,
                    provider=datasource_provider_id.provider_name,
                    plugin_id=datasource_provider_id.plugin_id,
                )
                .first()
            )
            if target_provider is None:
                raise ValueError("provider not found")
            # Clear the previous default before flagging the new one.
            session.query(DatasourceProvider).filter_by(
                tenant_id=tenant_id,
                provider=target_provider.provider,
                plugin_id=target_provider.plugin_id,
                is_default=True,
            ).update({"is_default": False})
            target_provider.is_default = True
            session.commit()
        return {"result": "success"}

    def setup_oauth_custom_client_params(
        self,
        tenant_id: str,
        datasource_provider_id: DatasourceProviderID,
        client_params: dict | None,
        enabled: bool | None,
    ):
        """
        Create or update the tenant-level OAuth client params of a datasource provider.

        :param client_params: new client params, or ``None`` to leave them unchanged. A value
            equal to ``HIDDEN_VALUE`` keeps the previously stored value for that key
            (``UNKNOWN_VALUE`` if there was none).
        :param enabled: new enabled flag, or ``None`` to leave it unchanged
        """
        if client_params is None and enabled is None:
            return
        with Session(db.engine) as session:
            tenant_oauth_client_params = (
                session.query(DatasourceOauthTenantParamConfig)
                .filter_by(
                    tenant_id=tenant_id,
                    provider=datasource_provider_id.provider_name,
                    plugin_id=datasource_provider_id.plugin_id,
                )
                .first()
            )
            if not tenant_oauth_client_params:
                tenant_oauth_client_params = DatasourceOauthTenantParamConfig(
                    tenant_id=tenant_id,
                    provider=datasource_provider_id.provider_name,
                    plugin_id=datasource_provider_id.plugin_id,
                    client_params={},
                    enabled=False,
                )
                session.add(tenant_oauth_client_params)
            if client_params is not None:
                # Named to avoid shadowing the module-level ``encrypter`` helper.
                oauth_encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
                original_params = oauth_encrypter.decrypt(tenant_oauth_client_params.client_params)
                new_params: dict = {
                    key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
                    for key, value in client_params.items()
                }
                tenant_oauth_client_params.client_params = oauth_encrypter.encrypt(new_params)
            if enabled is not None:
                tenant_oauth_client_params.enabled = enabled
            session.commit()

    def is_system_oauth_params_exist(self, datasource_provider_id: DatasourceProviderID) -> bool:
        """
        Return whether system-level OAuth client params exist for a datasource provider.
        """
        # The Session itself is the outer context manager so it is closed on exit;
        # ``no_autoflush`` alone would leave it open.
        with Session(db.engine) as session, session.no_autoflush:
            return (
                session.query(DatasourceOauthParamConfig)
                .filter_by(provider=datasource_provider_id.provider_name, plugin_id=datasource_provider_id.plugin_id)
                .first()
                is not None
            )

    def is_tenant_oauth_params_enabled(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> bool:
        """
        Return whether the tenant has enabled OAuth client params for a datasource provider.
        """
        with Session(db.engine) as session, session.no_autoflush:
            return (
                session.query(DatasourceOauthTenantParamConfig)
                .filter_by(
                    tenant_id=tenant_id,
                    provider=datasource_provider_id.provider_name,
                    plugin_id=datasource_provider_id.plugin_id,
                    enabled=True,
                )
                .count()
                > 0
            )

    def get_tenant_oauth_client(
        self, tenant_id: str, datasource_provider_id: DatasourceProviderID, mask: bool = False
    ) -> dict[str, Any] | None:
        """
        Get the tenant-level OAuth client params of a datasource provider (enabled or not).

        :param mask: when True, return the decrypted params passed through
            ``mask_tool_credentials``; otherwise return them decrypted
        :return: the params, or ``None`` if the tenant has no config for this provider
        """
        with Session(db.engine) as session, session.no_autoflush:
            tenant_oauth_client_params = (
                session.query(DatasourceOauthTenantParamConfig)
                .filter_by(
                    tenant_id=tenant_id,
                    provider=datasource_provider_id.provider_name,
                    plugin_id=datasource_provider_id.plugin_id,
                )
                .first()
            )
            if not tenant_oauth_client_params:
                return None
            oauth_encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
            client_params = oauth_encrypter.decrypt(tenant_oauth_client_params.client_params)
            if mask:
                return oauth_encrypter.mask_tool_credentials(client_params)
            return client_params

    def get_oauth_encrypter(
        self, tenant_id: str, datasource_provider_id: DatasourceProviderID
    ) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
        """
        Build the encrypter for a provider's OAuth client params from its ``client_schema``.

        :raises ValueError: if the provider declares no OAuth schema
        """
        datasource_provider = self.provider_manager.fetch_datasource_provider(
            tenant_id=tenant_id, provider_id=str(datasource_provider_id)
        )
        if not datasource_provider.declaration.oauth_schema:
            raise ValueError("Datasource provider oauth schema not found")
        client_schema = datasource_provider.declaration.oauth_schema.client_schema
        return create_provider_encrypter(
            tenant_id=tenant_id,
            config=[x.to_basic_provider_config() for x in client_schema],
            cache=NoOpProviderCredentialCache(),
        )

    def get_oauth_client(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> dict[str, Any] | None:
        """
        Resolve the OAuth client params to use for a datasource provider.

        An enabled tenant-level config takes precedence and is returned decrypted;
        otherwise the system-level ``system_credentials`` are returned.

        :raises ValueError: if neither an enabled tenant config nor a system config exists
        """
        provider = datasource_provider_id.provider_name
        plugin_id = datasource_provider_id.plugin_id
        with Session(db.engine) as session, session.no_autoflush:
            tenant_oauth_client_params = (
                session.query(DatasourceOauthTenantParamConfig)
                .filter_by(
                    tenant_id=tenant_id,
                    provider=provider,
                    plugin_id=plugin_id,
                    enabled=True,
                )
                .first()
            )
            if tenant_oauth_client_params:
                oauth_encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
                return oauth_encrypter.decrypt(tenant_oauth_client_params.client_params)
            # Fall back to the system-level OAuth client params.
            oauth_client_params = (
                session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
            )
            if oauth_client_params:
                return oauth_client_params.system_credentials
        raise ValueError(f"Please configure oauth client params(system/tenant) for {plugin_id}/{provider}")

    @staticmethod
    def generate_next_datasource_provider_name(
        session: Session, tenant_id: str, provider_id: DatasourceProviderID, credential_type: CredentialType
    ) -> str:
        """
        Generate an unused credential name for a provider, based on the credential type's name.

        Existing names of all credentials of this (tenant, provider, plugin) are considered.
        """
        db_providers = (
            session.query(DatasourceProvider)
            .filter_by(
                tenant_id=tenant_id,
                provider=provider_id.provider_name,
                plugin_id=provider_id.plugin_id,
            )
            .all()
        )
        return generate_incremental_name(
            [provider.name for provider in db_providers],
            f"{credential_type.get_name()}",
        )

    def add_datasource_oauth_provider(
        self,
        name: str | None,
        tenant_id: str,
        provider_id: DatasourceProviderID,
        avatar_url: str | None,
        credentials: dict,
    ) -> None:
        """
        Store a new OAuth2 credential for a datasource provider.

        If ``name`` is empty a name is generated; if it clashes with an existing OAuth2
        credential of this provider, an incremental variant of it is used instead.
        Secret fields are encrypted before storage; ``credentials`` itself is not modified.
        """
        credential_type = CredentialType.OAUTH2
        with Session(db.engine) as session:
            # Serialize concurrent creations so generated names stay unique.
            lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{credential_type.value}"
            with redis_client.lock(lock, timeout=20):
                db_provider_name = name
                if not db_provider_name:
                    db_provider_name = self.generate_next_datasource_provider_name(
                        session=session,
                        tenant_id=tenant_id,
                        provider_id=provider_id,
                        credential_type=credential_type,
                    )
                elif (
                    session.query(DatasourceProvider)
                    .filter_by(
                        tenant_id=tenant_id,
                        name=db_provider_name,
                        provider=provider_id.provider_name,
                        plugin_id=provider_id.plugin_id,
                        auth_type=credential_type.value,
                    )
                    .count()
                    > 0
                ):
                    existing_providers = (
                        session.query(DatasourceProvider)
                        .filter_by(
                            tenant_id=tenant_id,
                            provider=provider_id.provider_name,
                            plugin_id=provider_id.plugin_id,
                        )
                        .all()
                    )
                    db_provider_name = generate_incremental_name(
                        [provider.name for provider in existing_providers],
                        db_provider_name,
                    )
                provider_credential_secret_variables = self.extract_secret_variables(
                    tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=credential_type
                )
                encrypted_credentials = self._encrypt_secret_fields(
                    tenant_id, credentials, provider_credential_secret_variables
                )
                datasource_provider = DatasourceProvider(
                    tenant_id=tenant_id,
                    name=db_provider_name,
                    provider=provider_id.provider_name,
                    plugin_id=provider_id.plugin_id,
                    auth_type=credential_type.value,
                    encrypted_credentials=encrypted_credentials,
                    avatar_url=avatar_url or "default",
                )
                session.add(datasource_provider)
                session.commit()

    def add_datasource_api_key_provider(
        self,
        name: str | None,
        tenant_id: str,
        provider_id: DatasourceProviderID,
        credentials: dict,
    ) -> None:
        """
        Validate and store a new API-key credential for a datasource provider.

        :param name: display name; a name is generated when empty
        :param tenant_id: workspace id
        :param provider_id: datasource provider the credential belongs to
        :param credentials: plaintext credentials; not modified by this call
        :raises ValueError: if the name is already used for this provider or the
            plugin rejects the credentials
        """
        credential_type = CredentialType.API_KEY
        provider_name = provider_id.provider_name
        plugin_id = provider_id.plugin_id
        with Session(db.engine) as session:
            # Use ``.value`` so the key format matches the OAuth creation path.
            lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{credential_type.value}"
            with redis_client.lock(lock, timeout=20):
                db_provider_name = name or self.generate_next_datasource_provider_name(
                    session=session,
                    tenant_id=tenant_id,
                    provider_id=provider_id,
                    credential_type=credential_type,
                )
                if (
                    session.query(DatasourceProvider)
                    .filter_by(tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name, name=db_provider_name)
                    .count()
                    > 0
                ):
                    raise ValueError("Authorization name is already exists")
                try:
                    self.provider_manager.validate_provider_credentials(
                        tenant_id=tenant_id,
                        user_id=current_user.id,
                        provider=provider_name,
                        plugin_id=plugin_id,
                        credentials=credentials,
                    )
                except Exception as e:
                    raise ValueError(f"Failed to validate credentials: {str(e)}") from e
                provider_credential_secret_variables = self.extract_secret_variables(
                    tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=credential_type
                )
                encrypted_credentials = self._encrypt_secret_fields(
                    tenant_id, credentials, provider_credential_secret_variables
                )
                datasource_provider = DatasourceProvider(
                    tenant_id=tenant_id,
                    name=db_provider_name,
                    provider=provider_name,
                    plugin_id=plugin_id,
                    auth_type=credential_type.value,
                    encrypted_credentials=encrypted_credentials,
                )
                # Persist through the same session (and under the same lock) as the name check.
                session.add(datasource_provider)
                session.commit()

    def extract_secret_variables(self, tenant_id: str, provider_id: str, credential_type: CredentialType) -> list[str]:
        """
        Return the names of the ``SECRET_INPUT`` fields in a provider's credential schema.

        :param tenant_id: workspace id
        :param provider_id: ``"<plugin_id>/<provider>"`` string of the datasource provider
        :param credential_type: selects ``credentials_schema`` (API_KEY) or
            ``oauth_schema.credentials_schema`` (OAUTH2)
        :raises ValueError: for an unsupported credential type, or OAUTH2 without an OAuth schema
        """
        datasource_provider = self.provider_manager.fetch_datasource_provider(
            tenant_id=tenant_id, provider_id=provider_id
        )
        if credential_type == CredentialType.API_KEY:
            credential_form_schemas = list(datasource_provider.declaration.credentials_schema)
        elif credential_type == CredentialType.OAUTH2:
            if not datasource_provider.declaration.oauth_schema:
                raise ValueError("Datasource provider oauth schema not found")
            credential_form_schemas = list(datasource_provider.declaration.oauth_schema.credentials_schema)
        else:
            raise ValueError(f"Invalid credential type: {credential_type}")
        return [
            form_schema.name
            for form_schema in credential_form_schemas
            if form_schema.type.value == FormType.SECRET_INPUT.value
        ]

    def get_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
        """
        List a provider's stored credentials with secret fields obfuscated.

        :param tenant_id: workspace id
        :param provider: provider name
        :param plugin_id: plugin id
        :return: one dict per credential with ``credential``, ``type``, ``name``,
            ``avatar_url``, ``id`` and ``is_default`` keys
        """
        datasource_providers: list[DatasourceProvider] = (
            db.session.query(DatasourceProvider)
            .filter(
                DatasourceProvider.tenant_id == tenant_id,
                DatasourceProvider.provider == provider,
                DatasourceProvider.plugin_id == plugin_id,
            )
            .all()
        )
        if not datasource_providers:
            return []
        # Same rule as get_default_credentials: flagged default first, else the oldest.
        default_provider = (
            db.session.query(DatasourceProvider.id)
            .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
            .order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
            .first()
        )
        default_provider_id = default_provider.id if default_provider else None
        # extract_secret_variables fetches the provider declaration; do it once per auth type.
        secret_variables_by_auth_type: dict[Any, list[str]] = {}
        copy_credentials_list = []
        for datasource_provider in datasource_providers:
            auth_type = datasource_provider.auth_type
            if auth_type not in secret_variables_by_auth_type:
                secret_variables_by_auth_type[auth_type] = self.extract_secret_variables(
                    tenant_id=tenant_id,
                    provider_id=f"{plugin_id}/{provider}",
                    credential_type=CredentialType.of(auth_type),
                )
            credential_secret_variables = secret_variables_by_auth_type[auth_type]
            obfuscated_credentials = {
                key: encrypter.obfuscated_token(value) if key in credential_secret_variables else value
                for key, value in datasource_provider.encrypted_credentials.items()
            }
            copy_credentials_list.append(
                {
                    "credential": obfuscated_credentials,
                    "type": auth_type,
                    "name": datasource_provider.name,
                    "avatar_url": datasource_provider.avatar_url,
                    "id": datasource_provider.id,
                    "is_default": bool(default_provider_id) and datasource_provider.id == default_provider_id,
                }
            )
        return copy_credentials_list

    def get_all_datasource_credentials(self, tenant_id: str) -> list[dict]:
        """
        Describe every installed datasource provider of the tenant together with its
        (obfuscated) credentials, credential schema and OAuth configuration.

        :param tenant_id: workspace id
        :return: one dict per installed datasource provider
        """
        datasources = self.provider_manager.fetch_installed_datasource_providers(tenant_id)
        datasource_credentials = []
        for datasource in datasources:
            datasource_provider_id = DatasourceProviderID(f"{datasource.plugin_id}/{datasource.provider}")
            credentials = self.get_datasource_credentials(
                tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
            )
            redirect_uri = (
                f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{datasource_provider_id}/datasource/callback"
            )
            declaration = datasource.declaration
            oauth_schema = declaration.oauth_schema
            # Providers without an OAuth schema report ``oauth_schema: None``.
            oauth_schema_payload = None
            if oauth_schema:
                oauth_schema_payload = {
                    "client_schema": [client_schema.model_dump() for client_schema in oauth_schema.client_schema],
                    "credentials_schema": [
                        credential_schema.model_dump() for credential_schema in oauth_schema.credentials_schema
                    ],
                    "oauth_custom_client_params": self.get_tenant_oauth_client(
                        tenant_id, datasource_provider_id, mask=True
                    ),
                    "is_oauth_custom_client_enabled": self.is_tenant_oauth_params_enabled(
                        tenant_id, datasource_provider_id
                    ),
                    "is_system_oauth_params_exists": self.is_system_oauth_params_exist(datasource_provider_id),
                    "redirect_uri": redirect_uri,
                }
            datasource_credentials.append(
                {
                    "provider": datasource.provider,
                    "plugin_id": datasource.plugin_id,
                    "plugin_unique_identifier": datasource.plugin_unique_identifier,
                    "icon": declaration.identity.icon,
                    "name": declaration.identity.name.split("/")[-1],
                    "label": declaration.identity.label.model_dump(),
                    "description": declaration.identity.description.model_dump(),
                    "author": declaration.identity.author,
                    "credentials_list": credentials,
                    "credential_schema": [credential.model_dump() for credential in declaration.credentials_schema],
                    "oauth_schema": oauth_schema_payload,
                }
            )
        return datasource_credentials

    def get_real_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
        """
        List a provider's stored credentials with secret fields decrypted.

        :param tenant_id: workspace id
        :param provider: provider name
        :param plugin_id: plugin id
        :return: one ``{"credentials": ..., "type": ...}`` dict per stored credential
        """
        datasource_providers: list[DatasourceProvider] = (
            db.session.query(DatasourceProvider)
            .filter(
                DatasourceProvider.tenant_id == tenant_id,
                DatasourceProvider.provider == provider,
                DatasourceProvider.plugin_id == plugin_id,
            )
            .all()
        )
        if not datasource_providers:
            return []
        # extract_secret_variables fetches the provider declaration; do it once per auth type.
        secret_variables_by_auth_type: dict[Any, list[str]] = {}
        copy_credentials_list = []
        for datasource_provider in datasource_providers:
            auth_type = datasource_provider.auth_type
            if auth_type not in secret_variables_by_auth_type:
                secret_variables_by_auth_type[auth_type] = self.extract_secret_variables(
                    tenant_id=tenant_id,
                    provider_id=f"{plugin_id}/{provider}",
                    credential_type=CredentialType.of(auth_type),
                )
            copy_credentials_list.append(
                {
                    "credentials": self._decrypt_secret_fields(
                        tenant_id,
                        datasource_provider.encrypted_credentials,
                        secret_variables_by_auth_type[auth_type],
                    ),
                    "type": auth_type,
                }
            )
        return copy_credentials_list

    def update_datasource_credentials(
        self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict | None, name: str | None
    ) -> None:
        """
        Update the name and/or credentials of an existing datasource credential.

        Submitted fields equal to ``HIDDEN_VALUE`` keep their currently stored value
        (``UNKNOWN_VALUE`` if there is none). The merged credentials are validated by
        the plugin before secret fields are re-encrypted and saved.

        :raises ValueError: if the credential does not exist, the new name is already used
            for this provider, or credential validation fails
        """
        with Session(db.engine) as session:
            datasource_provider = (
                session.query(DatasourceProvider)
                .filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
                .first()
            )
            if not datasource_provider:
                raise ValueError("Datasource provider not found")
            if name and name != datasource_provider.name:
                if (
                    session.query(DatasourceProvider)
                    .filter_by(tenant_id=tenant_id, name=name, provider=provider, plugin_id=plugin_id)
                    .count()
                    > 0
                ):
                    raise ValueError("Authorization name is already exists")
                datasource_provider.name = name
            if credentials:
                secret_variables = self.extract_secret_variables(
                    tenant_id=tenant_id,
                    provider_id=f"{plugin_id}/{provider}",
                    credential_type=CredentialType.of(datasource_provider.auth_type),
                )
                original_credentials = self._decrypt_secret_fields(
                    tenant_id, datasource_provider.encrypted_credentials, secret_variables
                )
                new_credentials = {
                    key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE)
                    for key, value in credentials.items()
                }
                try:
                    self.provider_manager.validate_provider_credentials(
                        tenant_id=tenant_id,
                        user_id=current_user.id,
                        provider=provider,
                        plugin_id=plugin_id,
                        credentials=new_credentials,
                    )
                except Exception as e:
                    raise ValueError(f"Failed to validate credentials: {str(e)}") from e
                datasource_provider.encrypted_credentials = self._encrypt_secret_fields(
                    tenant_id, new_credentials, secret_variables
                )
            session.commit()

    def remove_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str) -> None:
        """
        Delete a stored datasource credential; does nothing if it does not exist.

        :param tenant_id: workspace id
        :param auth_id: id of the ``DatasourceProvider`` row
        :param provider: provider name
        :param plugin_id: plugin id
        """
        datasource_provider = (
            db.session.query(DatasourceProvider)
            .filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
            .first()
        )
        if datasource_provider:
            db.session.delete(datasource_provider)
            db.session.commit()