mirror of
https://github.com/langgenius/dify.git
synced 2025-09-26 09:36:13 +00:00
721 lines
30 KiB
Python
721 lines
30 KiB
Python
import logging
|
|
from typing import Any
|
|
|
|
from flask_login import current_user
|
|
from sqlalchemy.orm import Session
|
|
|
|
from configs import dify_config
|
|
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
|
|
from core.helper import encrypter
|
|
from core.helper.name_generator import generate_incremental_name
|
|
from core.helper.provider_cache import NoOpProviderCredentialCache
|
|
from core.model_runtime.entities.provider_entities import FormType
|
|
from core.plugin.entities.plugin import DatasourceProviderID
|
|
from core.plugin.impl.datasource import PluginDatasourceManager
|
|
from core.tools.entities.tool_entities import CredentialType
|
|
from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
|
|
from extensions.ext_database import db
|
|
from extensions.ext_redis import redis_client
|
|
from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class DatasourceProviderService:
|
|
"""
|
|
Model Provider Service
|
|
"""
|
|
|
|
def __init__(self) -> None:
|
|
self.provider_manager = PluginDatasourceManager()
|
|
|
|
def remove_oauth_custom_client_params(self, tenant_id: str, datasource_provider_id: DatasourceProviderID):
|
|
"""
|
|
remove oauth custom client params
|
|
"""
|
|
with Session(db.engine) as session:
|
|
session.query(DatasourceOauthTenantParamConfig).filter_by(
|
|
tenant_id=tenant_id,
|
|
provider=datasource_provider_id.provider_name,
|
|
plugin_id=datasource_provider_id.plugin_id,
|
|
).delete()
|
|
session.commit()
|
|
|
|
def get_default_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> dict[str, Any]:
|
|
"""
|
|
get default credentials
|
|
"""
|
|
with Session(db.engine) as session:
|
|
datasource_provider = (
|
|
session.query(DatasourceProvider)
|
|
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
|
|
.order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
|
|
.first()
|
|
)
|
|
if not datasource_provider:
|
|
return {}
|
|
return datasource_provider.encrypted_credentials
|
|
|
|
def get_real_credential_by_id(
|
|
self, tenant_id: str, credential_id: str, provider: str, plugin_id: str
|
|
) -> dict[str, Any]:
|
|
"""
|
|
get credential by id
|
|
"""
|
|
with Session(db.engine) as session:
|
|
datasource_provider = (
|
|
session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
|
|
)
|
|
if not datasource_provider:
|
|
return {}
|
|
encrypted_credentials = datasource_provider.encrypted_credentials
|
|
# Get provider credential secret variables
|
|
credential_secret_variables = self.extract_secret_variables(
|
|
tenant_id=tenant_id,
|
|
provider_id=f"{plugin_id}/{provider}",
|
|
credential_type=CredentialType.of(datasource_provider.auth_type),
|
|
)
|
|
|
|
# Obfuscate provider credentials
|
|
copy_credentials = encrypted_credentials.copy()
|
|
for key, value in copy_credentials.items():
|
|
if key in credential_secret_variables:
|
|
copy_credentials[key] = encrypter.decrypt_token(tenant_id, value)
|
|
return copy_credentials
|
|
|
|
def update_datasource_provider_name(
|
|
self, tenant_id: str, datasource_provider_id: DatasourceProviderID, name: str, credential_id: str
|
|
):
|
|
"""
|
|
update datasource provider name
|
|
"""
|
|
with Session(db.engine) as session:
|
|
target_provider = (
|
|
session.query(DatasourceProvider)
|
|
.filter_by(
|
|
tenant_id=tenant_id,
|
|
id=credential_id,
|
|
provider=datasource_provider_id.provider_name,
|
|
plugin_id=datasource_provider_id.plugin_id,
|
|
)
|
|
.first()
|
|
)
|
|
if target_provider is None:
|
|
raise ValueError("provider not found")
|
|
|
|
if target_provider.name == name:
|
|
return
|
|
|
|
# check name is exist
|
|
if (
|
|
session.query(DatasourceProvider)
|
|
.filter_by(
|
|
tenant_id=tenant_id,
|
|
name=name,
|
|
provider=datasource_provider_id.provider_name,
|
|
plugin_id=datasource_provider_id.plugin_id,
|
|
)
|
|
.count()
|
|
> 0
|
|
):
|
|
raise ValueError("Authorization name is already exists")
|
|
|
|
target_provider.name = name
|
|
session.commit()
|
|
return
|
|
|
|
def set_default_datasource_provider(
|
|
self, tenant_id: str, datasource_provider_id: DatasourceProviderID, credential_id: str
|
|
):
|
|
"""
|
|
set default datasource provider
|
|
"""
|
|
with Session(db.engine) as session:
|
|
# get provider
|
|
target_provider = (
|
|
session.query(DatasourceProvider)
|
|
.filter_by(
|
|
tenant_id=tenant_id,
|
|
id=credential_id,
|
|
provider=datasource_provider_id.provider_name,
|
|
plugin_id=datasource_provider_id.plugin_id,
|
|
)
|
|
.first()
|
|
)
|
|
if target_provider is None:
|
|
raise ValueError("provider not found")
|
|
|
|
# clear default provider
|
|
session.query(DatasourceProvider).filter_by(
|
|
tenant_id=tenant_id,
|
|
provider=target_provider.provider,
|
|
plugin_id=target_provider.plugin_id,
|
|
is_default=True,
|
|
).update({"is_default": False})
|
|
|
|
# set new default provider
|
|
target_provider.is_default = True
|
|
session.commit()
|
|
return {"result": "success"}
|
|
|
|
def setup_oauth_custom_client_params(
|
|
self,
|
|
tenant_id: str,
|
|
datasource_provider_id: DatasourceProviderID,
|
|
client_params: dict | None,
|
|
enabled: bool | None,
|
|
):
|
|
"""
|
|
setup oauth custom client params
|
|
"""
|
|
if client_params is None and enabled is None:
|
|
return
|
|
with Session(db.engine) as session:
|
|
tenant_oauth_client_params = (
|
|
session.query(DatasourceOauthTenantParamConfig)
|
|
.filter_by(
|
|
tenant_id=tenant_id,
|
|
provider=datasource_provider_id.provider_name,
|
|
plugin_id=datasource_provider_id.plugin_id,
|
|
)
|
|
.first()
|
|
)
|
|
|
|
if not tenant_oauth_client_params:
|
|
tenant_oauth_client_params = DatasourceOauthTenantParamConfig(
|
|
tenant_id=tenant_id,
|
|
provider=datasource_provider_id.provider_name,
|
|
plugin_id=datasource_provider_id.plugin_id,
|
|
client_params={},
|
|
enabled=False,
|
|
)
|
|
session.add(tenant_oauth_client_params)
|
|
|
|
if client_params is not None:
|
|
encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
|
|
original_params = (
|
|
encrypter.decrypt(tenant_oauth_client_params.client_params) if tenant_oauth_client_params else {}
|
|
)
|
|
new_params: dict = {
|
|
key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
|
|
for key, value in client_params.items()
|
|
}
|
|
tenant_oauth_client_params.client_params = encrypter.encrypt(new_params)
|
|
|
|
if enabled is not None:
|
|
tenant_oauth_client_params.enabled = enabled
|
|
session.commit()
|
|
|
|
def is_system_oauth_params_exist(self, datasource_provider_id: DatasourceProviderID) -> bool:
|
|
"""
|
|
check if system oauth params exist
|
|
"""
|
|
with Session(db.engine).no_autoflush as session:
|
|
return (
|
|
session.query(DatasourceOauthParamConfig)
|
|
.filter_by(provider=datasource_provider_id.provider_name, plugin_id=datasource_provider_id.plugin_id)
|
|
.first()
|
|
is not None
|
|
)
|
|
|
|
def is_tenant_oauth_params_enabled(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> bool:
|
|
"""
|
|
check if tenant oauth params is enabled
|
|
"""
|
|
with Session(db.engine).no_autoflush as session:
|
|
return (
|
|
session.query(DatasourceOauthTenantParamConfig)
|
|
.filter_by(
|
|
tenant_id=tenant_id,
|
|
provider=datasource_provider_id.provider_name,
|
|
plugin_id=datasource_provider_id.plugin_id,
|
|
enabled=True,
|
|
)
|
|
.count()
|
|
> 0
|
|
)
|
|
|
|
def get_tenant_oauth_client(
|
|
self, tenant_id: str, datasource_provider_id: DatasourceProviderID, mask: bool = False
|
|
) -> dict[str, Any] | None:
|
|
"""
|
|
get tenant oauth client
|
|
"""
|
|
with Session(db.engine).no_autoflush as session:
|
|
tenant_oauth_client_params = (
|
|
session.query(DatasourceOauthTenantParamConfig)
|
|
.filter_by(
|
|
tenant_id=tenant_id,
|
|
provider=datasource_provider_id.provider_name,
|
|
plugin_id=datasource_provider_id.plugin_id,
|
|
)
|
|
.first()
|
|
)
|
|
if tenant_oauth_client_params:
|
|
encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
|
|
if mask:
|
|
return encrypter.mask_tool_credentials(encrypter.decrypt(tenant_oauth_client_params.client_params))
|
|
else:
|
|
return encrypter.decrypt(tenant_oauth_client_params.client_params)
|
|
return None
|
|
|
|
def get_oauth_encrypter(
|
|
self, tenant_id: str, datasource_provider_id: DatasourceProviderID
|
|
) -> tuple[ProviderConfigEncrypter, ProviderConfigCache]:
|
|
"""
|
|
get oauth encrypter
|
|
"""
|
|
datasource_provider = self.provider_manager.fetch_datasource_provider(
|
|
tenant_id=tenant_id, provider_id=str(datasource_provider_id)
|
|
)
|
|
if not datasource_provider.declaration.oauth_schema:
|
|
raise ValueError("Datasource provider oauth schema not found")
|
|
|
|
client_schema = datasource_provider.declaration.oauth_schema.client_schema
|
|
return create_provider_encrypter(
|
|
tenant_id=tenant_id,
|
|
config=[x.to_basic_provider_config() for x in client_schema],
|
|
cache=NoOpProviderCredentialCache(),
|
|
)
|
|
|
|
def get_oauth_client(self, tenant_id: str, datasource_provider_id: DatasourceProviderID) -> dict[str, Any] | None:
|
|
"""
|
|
get oauth client
|
|
"""
|
|
provider = datasource_provider_id.provider_name
|
|
plugin_id = datasource_provider_id.plugin_id
|
|
with Session(db.engine).no_autoflush as session:
|
|
# get tenant oauth client params
|
|
tenant_oauth_client_params = (
|
|
session.query(DatasourceOauthTenantParamConfig)
|
|
.filter_by(
|
|
tenant_id=tenant_id,
|
|
provider=provider,
|
|
plugin_id=plugin_id,
|
|
enabled=True,
|
|
)
|
|
.first()
|
|
)
|
|
if tenant_oauth_client_params:
|
|
encrypter, _ = self.get_oauth_encrypter(tenant_id, datasource_provider_id)
|
|
return encrypter.decrypt(tenant_oauth_client_params.client_params)
|
|
|
|
# fallback to system oauth client params
|
|
oauth_client_params = (
|
|
session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
|
|
)
|
|
if oauth_client_params:
|
|
return oauth_client_params.system_credentials
|
|
|
|
raise ValueError(f"Please configure oauth client params(system/tenant) for {plugin_id}/{provider}")
|
|
|
|
@staticmethod
|
|
def generate_next_datasource_provider_name(
|
|
session: Session, tenant_id: str, provider_id: DatasourceProviderID, credential_type: CredentialType
|
|
) -> str:
|
|
db_providers = (
|
|
session.query(DatasourceProvider)
|
|
.filter_by(
|
|
tenant_id=tenant_id,
|
|
provider=provider_id.provider_name,
|
|
plugin_id=provider_id.plugin_id,
|
|
)
|
|
.all()
|
|
)
|
|
return generate_incremental_name(
|
|
[provider.name for provider in db_providers],
|
|
f"{credential_type.get_name()}",
|
|
)
|
|
|
|
def add_datasource_oauth_provider(
|
|
self,
|
|
name: str | None,
|
|
tenant_id: str,
|
|
provider_id: DatasourceProviderID,
|
|
avatar_url: str | None,
|
|
credentials: dict,
|
|
) -> None:
|
|
"""
|
|
add datasource oauth provider
|
|
"""
|
|
credential_type = CredentialType.OAUTH2
|
|
with Session(db.engine) as session:
|
|
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{credential_type.value}"
|
|
with redis_client.lock(lock, timeout=20):
|
|
db_provider_name = name
|
|
if not db_provider_name:
|
|
db_provider_name = self.generate_next_datasource_provider_name(
|
|
session=session,
|
|
tenant_id=tenant_id,
|
|
provider_id=provider_id,
|
|
credential_type=credential_type,
|
|
)
|
|
else:
|
|
if (
|
|
session.query(DatasourceProvider)
|
|
.filter_by(
|
|
tenant_id=tenant_id,
|
|
name=db_provider_name,
|
|
provider=provider_id.provider_name,
|
|
plugin_id=provider_id.plugin_id,
|
|
auth_type=credential_type.value,
|
|
)
|
|
.count()
|
|
> 0
|
|
):
|
|
db_provider_name = generate_incremental_name(
|
|
[
|
|
provider.name
|
|
for provider in session.query(DatasourceProvider).filter_by(
|
|
tenant_id=tenant_id,
|
|
provider=provider_id.provider_name,
|
|
plugin_id=provider_id.plugin_id,
|
|
)
|
|
],
|
|
db_provider_name,
|
|
)
|
|
|
|
provider_credential_secret_variables = self.extract_secret_variables(
|
|
tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=credential_type
|
|
)
|
|
for key, value in credentials.items():
|
|
if key in provider_credential_secret_variables:
|
|
credentials[key] = encrypter.encrypt_token(tenant_id, value)
|
|
|
|
datasource_provider = DatasourceProvider(
|
|
tenant_id=tenant_id,
|
|
name=db_provider_name,
|
|
provider=provider_id.provider_name,
|
|
plugin_id=provider_id.plugin_id,
|
|
auth_type=credential_type.value,
|
|
encrypted_credentials=credentials,
|
|
avatar_url=avatar_url or "default",
|
|
)
|
|
session.add(datasource_provider)
|
|
session.commit()
|
|
|
|
def add_datasource_api_key_provider(
|
|
self,
|
|
name: str | None,
|
|
tenant_id: str,
|
|
provider_id: DatasourceProviderID,
|
|
credentials: dict,
|
|
) -> None:
|
|
"""
|
|
validate datasource provider credentials.
|
|
|
|
:param tenant_id:
|
|
:param provider:
|
|
:param credentials:
|
|
"""
|
|
provider_name = provider_id.provider_name
|
|
plugin_id = provider_id.plugin_id
|
|
with Session(db.engine) as session:
|
|
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY}"
|
|
with redis_client.lock(lock, timeout=20):
|
|
db_provider_name = name or self.generate_next_datasource_provider_name(
|
|
session=session,
|
|
tenant_id=tenant_id,
|
|
provider_id=provider_id,
|
|
credential_type=CredentialType.API_KEY,
|
|
)
|
|
|
|
# check name is exist
|
|
if (
|
|
session.query(DatasourceProvider)
|
|
.filter_by(tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name, name=db_provider_name)
|
|
.count()
|
|
> 0
|
|
):
|
|
raise ValueError("Authorization name is already exists")
|
|
|
|
try:
|
|
self.provider_manager.validate_provider_credentials(
|
|
tenant_id=tenant_id,
|
|
user_id=current_user.id,
|
|
provider=provider_name,
|
|
plugin_id=plugin_id,
|
|
credentials=credentials,
|
|
)
|
|
except Exception as e:
|
|
raise ValueError(f"Failed to validate credentials: {str(e)}")
|
|
|
|
provider_credential_secret_variables = self.extract_secret_variables(
|
|
tenant_id=tenant_id, provider_id=f"{provider_id}", credential_type=CredentialType.API_KEY
|
|
)
|
|
for key, value in credentials.items():
|
|
if key in provider_credential_secret_variables:
|
|
# if send [__HIDDEN__] in secret input, it will be same as original value
|
|
credentials[key] = encrypter.encrypt_token(tenant_id, value)
|
|
datasource_provider = DatasourceProvider(
|
|
tenant_id=tenant_id,
|
|
name=db_provider_name,
|
|
provider=provider_name,
|
|
plugin_id=plugin_id,
|
|
auth_type=CredentialType.API_KEY.value,
|
|
encrypted_credentials=credentials,
|
|
)
|
|
db.session.add(datasource_provider)
|
|
db.session.commit()
|
|
|
|
def extract_secret_variables(self, tenant_id: str, provider_id: str, credential_type: CredentialType) -> list[str]:
|
|
"""
|
|
Extract secret input form variables.
|
|
|
|
:param credential_form_schemas:
|
|
:return:
|
|
"""
|
|
datasource_provider = self.provider_manager.fetch_datasource_provider(
|
|
tenant_id=tenant_id, provider_id=provider_id
|
|
)
|
|
credential_form_schemas = []
|
|
if credential_type == CredentialType.API_KEY:
|
|
credential_form_schemas = list(datasource_provider.declaration.credentials_schema)
|
|
elif credential_type == CredentialType.OAUTH2:
|
|
if not datasource_provider.declaration.oauth_schema:
|
|
raise ValueError("Datasource provider oauth schema not found")
|
|
credential_form_schemas = list(datasource_provider.declaration.oauth_schema.credentials_schema)
|
|
else:
|
|
raise ValueError(f"Invalid credential type: {credential_type}")
|
|
|
|
secret_input_form_variables = []
|
|
for credential_form_schema in credential_form_schemas:
|
|
if credential_form_schema.type.value == FormType.SECRET_INPUT.value:
|
|
secret_input_form_variables.append(credential_form_schema.name)
|
|
|
|
return secret_input_form_variables
|
|
|
|
def get_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
|
|
"""
|
|
get datasource credentials.
|
|
|
|
:param tenant_id: workspace id
|
|
:param provider_id: provider id
|
|
:return:
|
|
"""
|
|
# Get all provider configurations of the current workspace
|
|
datasource_providers: list[DatasourceProvider] = (
|
|
db.session.query(DatasourceProvider)
|
|
.filter(
|
|
DatasourceProvider.tenant_id == tenant_id,
|
|
DatasourceProvider.provider == provider,
|
|
DatasourceProvider.plugin_id == plugin_id,
|
|
)
|
|
.all()
|
|
)
|
|
if not datasource_providers:
|
|
return []
|
|
copy_credentials_list = []
|
|
default_provider = (
|
|
db.session.query(DatasourceProvider.id)
|
|
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
|
|
.order_by(DatasourceProvider.is_default.desc(), DatasourceProvider.created_at.asc())
|
|
.first()
|
|
)
|
|
default_provider_id = default_provider.id if default_provider else None
|
|
for datasource_provider in datasource_providers:
|
|
encrypted_credentials = datasource_provider.encrypted_credentials
|
|
# Get provider credential secret variables
|
|
credential_secret_variables = self.extract_secret_variables(
|
|
tenant_id=tenant_id,
|
|
provider_id=f"{plugin_id}/{provider}",
|
|
credential_type=CredentialType.of(datasource_provider.auth_type),
|
|
)
|
|
|
|
# Obfuscate provider credentials
|
|
copy_credentials = encrypted_credentials.copy()
|
|
for key, value in copy_credentials.items():
|
|
if key in credential_secret_variables:
|
|
copy_credentials[key] = encrypter.obfuscated_token(value)
|
|
copy_credentials_list.append(
|
|
{
|
|
"credential": copy_credentials,
|
|
"type": datasource_provider.auth_type,
|
|
"name": datasource_provider.name,
|
|
"avatar_url": datasource_provider.avatar_url,
|
|
"id": datasource_provider.id,
|
|
"is_default": default_provider_id and datasource_provider.id == default_provider_id,
|
|
}
|
|
)
|
|
|
|
return copy_credentials_list
|
|
|
|
def get_all_datasource_credentials(self, tenant_id: str) -> list[dict]:
|
|
"""
|
|
get datasource credentials.
|
|
|
|
:return:
|
|
"""
|
|
# get all plugin providers
|
|
manager = PluginDatasourceManager()
|
|
datasources = manager.fetch_installed_datasource_providers(tenant_id)
|
|
datasource_credentials = []
|
|
for datasource in datasources:
|
|
datasource_provider_id = DatasourceProviderID(f"{datasource.plugin_id}/{datasource.provider}")
|
|
credentials = self.get_datasource_credentials(
|
|
tenant_id=tenant_id, provider=datasource.provider, plugin_id=datasource.plugin_id
|
|
)
|
|
redirect_uri = (
|
|
f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{datasource_provider_id}/datasource/callback"
|
|
)
|
|
datasource_credentials.append(
|
|
{
|
|
"provider": datasource.provider,
|
|
"plugin_id": datasource.plugin_id,
|
|
"plugin_unique_identifier": datasource.plugin_unique_identifier,
|
|
"icon": datasource.declaration.identity.icon,
|
|
"name": datasource.declaration.identity.name.split("/")[-1],
|
|
"label": datasource.declaration.identity.label.model_dump(),
|
|
"description": datasource.declaration.identity.description.model_dump(),
|
|
"author": datasource.declaration.identity.author,
|
|
"credentials_list": credentials,
|
|
"credential_schema": [
|
|
credential.model_dump() for credential in datasource.declaration.credentials_schema
|
|
],
|
|
"oauth_schema": {
|
|
"client_schema": [
|
|
client_schema.model_dump()
|
|
for client_schema in datasource.declaration.oauth_schema.client_schema
|
|
],
|
|
"credentials_schema": [
|
|
credential_schema.model_dump()
|
|
for credential_schema in datasource.declaration.oauth_schema.credentials_schema
|
|
],
|
|
"oauth_custom_client_params": self.get_tenant_oauth_client(
|
|
tenant_id, datasource_provider_id, mask=True
|
|
),
|
|
"is_oauth_custom_client_enabled": self.is_tenant_oauth_params_enabled(
|
|
tenant_id, datasource_provider_id
|
|
),
|
|
"is_system_oauth_params_exists": self.is_system_oauth_params_exist(datasource_provider_id),
|
|
"redirect_uri": redirect_uri,
|
|
}
|
|
if datasource.declaration.oauth_schema
|
|
else None,
|
|
}
|
|
)
|
|
return datasource_credentials
|
|
|
|
def get_real_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> list[dict]:
|
|
"""
|
|
get datasource credentials.
|
|
|
|
:param tenant_id: workspace id
|
|
:param provider_id: provider id
|
|
:return:
|
|
"""
|
|
# Get all provider configurations of the current workspace
|
|
datasource_providers: list[DatasourceProvider] = (
|
|
db.session.query(DatasourceProvider)
|
|
.filter(
|
|
DatasourceProvider.tenant_id == tenant_id,
|
|
DatasourceProvider.provider == provider,
|
|
DatasourceProvider.plugin_id == plugin_id,
|
|
)
|
|
.all()
|
|
)
|
|
if not datasource_providers:
|
|
return []
|
|
copy_credentials_list = []
|
|
for datasource_provider in datasource_providers:
|
|
encrypted_credentials = datasource_provider.encrypted_credentials
|
|
# Get provider credential secret variables
|
|
credential_secret_variables = self.extract_secret_variables(
|
|
tenant_id=tenant_id,
|
|
provider_id=f"{plugin_id}/{provider}",
|
|
credential_type=CredentialType.of(datasource_provider.auth_type),
|
|
)
|
|
|
|
# Obfuscate provider credentials
|
|
copy_credentials = encrypted_credentials.copy()
|
|
for key, value in copy_credentials.items():
|
|
if key in credential_secret_variables:
|
|
copy_credentials[key] = encrypter.decrypt_token(tenant_id, value)
|
|
copy_credentials_list.append(
|
|
{
|
|
"credentials": copy_credentials,
|
|
"type": datasource_provider.auth_type,
|
|
}
|
|
)
|
|
|
|
return copy_credentials_list
|
|
|
|
def update_datasource_credentials(
|
|
self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict | None, name: str | None
|
|
) -> None:
|
|
"""
|
|
update datasource credentials.
|
|
"""
|
|
with Session(db.engine) as session:
|
|
datasource_provider = (
|
|
session.query(DatasourceProvider)
|
|
.filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
|
|
.first()
|
|
)
|
|
if not datasource_provider:
|
|
raise ValueError("Datasource provider not found")
|
|
# update name
|
|
if name and name != datasource_provider.name:
|
|
if (
|
|
session.query(DatasourceProvider)
|
|
.filter_by(tenant_id=tenant_id, name=name, provider=provider, plugin_id=plugin_id)
|
|
.count()
|
|
> 0
|
|
):
|
|
raise ValueError("Authorization name is already exists")
|
|
datasource_provider.name = name
|
|
|
|
# update credentials
|
|
if credentials:
|
|
secret_variables = self.extract_secret_variables(
|
|
tenant_id=tenant_id,
|
|
provider_id=f"{plugin_id}/{provider}",
|
|
credential_type=CredentialType.of(datasource_provider.auth_type),
|
|
)
|
|
original_credentials = {
|
|
key: value if key not in secret_variables else encrypter.decrypt_token(tenant_id, value)
|
|
for key, value in datasource_provider.encrypted_credentials.items()
|
|
}
|
|
new_credentials = {
|
|
key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE)
|
|
for key, value in credentials.items()
|
|
}
|
|
try:
|
|
self.provider_manager.validate_provider_credentials(
|
|
tenant_id=tenant_id,
|
|
user_id=current_user.id,
|
|
provider=provider,
|
|
plugin_id=plugin_id,
|
|
credentials=new_credentials,
|
|
)
|
|
except Exception as e:
|
|
raise ValueError(f"Failed to validate credentials: {str(e)}")
|
|
|
|
encrypted_credentials = {}
|
|
for key, value in new_credentials.items():
|
|
if key in secret_variables:
|
|
encrypted_credentials[key] = encrypter.encrypt_token(tenant_id, value)
|
|
else:
|
|
encrypted_credentials[key] = value
|
|
|
|
datasource_provider.encrypted_credentials = encrypted_credentials
|
|
session.commit()
|
|
|
|
def remove_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str) -> None:
|
|
"""
|
|
remove datasource credentials.
|
|
|
|
:param tenant_id: workspace id
|
|
:param provider: provider name
|
|
:param plugin_id: plugin id
|
|
:return:
|
|
"""
|
|
datasource_provider = (
|
|
db.session.query(DatasourceProvider)
|
|
.filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
|
|
.first()
|
|
)
|
|
if datasource_provider:
|
|
db.session.delete(datasource_provider)
|
|
db.session.commit()
|