feat: refactor OAuth provider handling and improve provider name generation

This commit is contained in:
Harry 2025-07-18 12:47:32 +08:00
parent 9f2a9ad271
commit 0ac5c0bf3e
4 changed files with 256 additions and 155 deletions

View File

@ -1,6 +1,5 @@
import random
from flask import redirect, request
from fastapi.encoders import jsonable_encoder
from flask import make_response, redirect, request
from flask_login import current_user # type: ignore
from flask_restful import ( # type: ignore
Resource, # type: ignore
@ -15,76 +14,101 @@ from controllers.console.wraps import (
setup_required,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.plugin.entities.plugin import DatasourceProviderID
from core.plugin.impl.oauth import OAuthHandler
from extensions.ext_database import db
from libs.login import login_required
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
from models.oauth import DatasourceOauthParamConfig
from services.datasource_provider_service import DatasourceProviderService
from services.plugin.oauth_service import OAuthProxyService
class DatasourcePluginOauthApi(Resource):
class DatasourcePluginOAuthAuthorizationUrl(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
provider = args["provider"]
plugin_id = args["plugin_id"]
# Check user role first
def get(self, provider: str):
user = current_user
tenant_id = user.current_tenant_id
if not current_user.is_editor:
raise Forbidden()
# get all plugin oauth configs
plugin_oauth_config = (
db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
provider_id = DatasourceProviderID(provider)
provider_name = provider_id.provider_name
plugin_id = provider_id.plugin_id
oauth_config = (
db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider_name, plugin_id=plugin_id).first()
)
if not oauth_config:
raise ValueError(f"No OAuth Client Config for {provider}")
context_id = OAuthProxyService.create_proxy_context(
user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
)
if not plugin_oauth_config:
raise NotFound()
oauth_handler = OAuthHandler()
redirect_url = (
f"{dify_config.CONSOLE_WEB_URL}/oauth/datasource/callback?provider={provider}&plugin_id={plugin_id}"
redirect_uri = f"{dify_config.CONSOLE_WEB_URL}/console/api/oauth/plugin/{provider}/datasource/callback"
oauth_client_params = oauth_config.system_credentials
authorization_url_response = oauth_handler.get_authorization_url(
tenant_id=tenant_id,
user_id=user.id,
plugin_id=plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=oauth_client_params,
)
system_credentials = plugin_oauth_config.system_credentials
if system_credentials:
system_credentials["redirect_url"] = redirect_url
response = oauth_handler.get_authorization_url(
current_user.current_tenant.id, current_user.id, plugin_id, provider, system_credentials=system_credentials
response = make_response(jsonable_encoder(authorization_url_response))
response.set_cookie(
"context_id",
context_id,
httponly=True,
samesite="Lax",
max_age=OAuthProxyService.__MAX_AGE__,
)
return response.model_dump()
return response
class DatasourceOauthCallback(Resource):
class DatasourceOAuthCallback(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
provider = args["provider"]
plugin_id = args["plugin_id"]
oauth_handler = OAuthHandler()
def get(self, provider: str):
if not current_user.is_editor:
raise Forbidden()
context_id = request.cookies.get("context_id")
if not context_id:
raise Forbidden("context_id not found")
context = OAuthProxyService.use_proxy_context(context_id)
if context is None:
raise Forbidden("Invalid context_id")
user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
provider_id = DatasourceProviderID(provider)
provider_name = provider_id.provider_name
plugin_id = provider_id.plugin_id
plugin_oauth_config = (
db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider_name, plugin_id=plugin_id).first()
)
if not plugin_oauth_config:
raise NotFound()
credentials = oauth_handler.get_credentials(
current_user.current_tenant.id,
current_user.id,
plugin_id,
provider,
redirect_uri = f"{dify_config.CONSOLE_WEB_URL}/console/api/oauth/plugin/{provider}/datasource/callback"
oauth_handler = OAuthHandler()
oauth_response = oauth_handler.get_credentials(
tenant_id=tenant_id,
user_id=user_id,
plugin_id=plugin_id,
provider=provider_id.provider_name,
redirect_uri=redirect_uri,
system_credentials=plugin_oauth_config.system_credentials,
request=request,
)
datasource_provider = DatasourceProvider(
plugin_id=plugin_id, provider=provider, auth_type="oauth", encrypted_credentials=credentials
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.add_datasource_oauth_provider(
tenant_id=tenant_id,
provider_id=provider_id,
credentials=dict(oauth_response.credentials),
name=None,
)
db.session.add(datasource_provider)
db.session.commit()
return redirect(f"{dify_config.CONSOLE_WEB_URL}")
@ -92,26 +116,23 @@ class DatasourceAuth(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
def post(self, provider: str):
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, required=False, nullable=False, location="json", default="test")
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, required=False, nullable=True, location="json", default=None)
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
provider_id = DatasourceProviderID(provider)
datasource_provider_service = DatasourceProviderService()
try:
datasource_provider_service.datasource_provider_credentials_validate(
datasource_provider_service.add_datasource_api_key_provider(
tenant_id=current_user.current_tenant_id,
provider=args["provider"],
plugin_id=args["plugin_id"],
provider_id=provider_id,
credentials=args["credentials"],
name="test" + str(random.randint(1, 1000000)), # noqa: S311
name=args["name"],
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
@ -121,14 +142,13 @@ class DatasourceAuth(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
def get(self, provider: str):
provider_id = DatasourceProviderID(provider)
datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.get_datasource_credentials(
tenant_id=current_user.current_tenant_id, provider=args["provider"], plugin_id=args["plugin_id"]
tenant_id=current_user.current_tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
)
return {"result": datasources}, 200
@ -137,29 +157,27 @@ class DatasourceAuthUpdateDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
def delete(self, auth_id: str):
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
def delete(self, provider: str, auth_id: str):
provider_id = DatasourceProviderID(provider)
plugin_id = provider_id.plugin_id
provider_name = provider_id.provider_name
if not current_user.is_editor:
raise Forbidden()
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_datasource_credentials(
tenant_id=current_user.current_tenant_id,
auth_id=auth_id,
provider=args["provider"],
plugin_id=args["plugin_id"],
provider=provider_name,
plugin_id=plugin_id,
)
return {"result": "success"}, 200
@setup_required
@login_required
@account_initialization_required
def patch(self, auth_id: str):
def patch(self, provider: str, auth_id: str):
provider_id = DatasourceProviderID(provider)
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
if not current_user.is_editor:
@ -169,8 +187,8 @@ class DatasourceAuthUpdateDeleteApi(Resource):
datasource_provider_service.update_datasource_credentials(
tenant_id=current_user.current_tenant_id,
auth_id=auth_id,
provider=args["provider"],
plugin_id=args["plugin_id"],
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
credentials=args["credentials"],
)
except CredentialsValidateFailedError as ex:
@ -193,21 +211,21 @@ class DatasourceAuthListApi(Resource):
# Import Rag Pipeline
api.add_resource(
DatasourcePluginOauthApi,
"/oauth/plugin/datasource",
DatasourcePluginOAuthAuthorizationUrl,
"/oauth/plugin/<path:provider>/datasource/get-authorization-url",
)
api.add_resource(
DatasourceOauthCallback,
"/oauth/plugin/datasource/callback",
DatasourceOAuthCallback,
"/oauth/plugin/<path:provider>/datasource/callback",
)
api.add_resource(
DatasourceAuth,
"/auth/plugin/datasource",
"/auth/plugin/datasource/<path:provider>",
)
api.add_resource(
DatasourceAuthUpdateDeleteApi,
"/auth/plugin/datasource/<string:auth_id>",
"/auth/plugin/datasource/<path:provider>/<string:auth_id>",
)
api.add_resource(

View File

@ -0,0 +1,35 @@
import logging
import re
from collections.abc import Sequence
from typing import Any
from core.tools.entities.tool_entities import CredentialType
logger = logging.getLogger(__name__)
def generate_provider_name(
providers: Sequence[Any],
credential_type: CredentialType,
fallback_context: str = "provider"
) -> str:
try:
default_pattern = f"{credential_type.get_name()}"
pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$"
numbers = []
for provider in providers:
if provider.name:
match = re.match(pattern, provider.name.strip())
if match:
numbers.append(int(match.group(1)))
if not numbers:
return f"{default_pattern} 1"
max_number = max(numbers)
return f"{default_pattern} {max_number + 1}"
except Exception as e:
logger.warning(f"Error generating next provider name for {fallback_context}: {str(e)}")
return f"{credential_type.get_name()} 1"

View File

@ -1,13 +1,18 @@
import logging
from flask_login import current_user
from sqlalchemy.orm import Session
from constants import HIDDEN_VALUE
from core.helper import encrypter
from core.helper.provider_name_generator import generate_provider_name
from core.model_runtime.entities.provider_entities import FormType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.plugin.entities.plugin import DatasourceProviderID
from core.plugin.impl.datasource import PluginDatasourceManager
from core.tools.entities.tool_entities import CredentialType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.oauth import DatasourceProvider
logger = logging.getLogger(__name__)
@ -21,8 +26,71 @@ class DatasourceProviderService:
def __init__(self) -> None:
self.provider_manager = PluginDatasourceManager()
def datasource_provider_credentials_validate(
self, tenant_id: str, provider: str, plugin_id: str, credentials: dict, name: str
@staticmethod
def generate_next_datasource_provider_name(
session: Session, tenant_id: str, provider_id: DatasourceProviderID, credential_type: CredentialType
) -> str:
db_providers = (
session.query(DatasourceProvider)
.filter_by(
tenant_id=tenant_id,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
auth_type=credential_type.value,
)
.all()
)
return generate_provider_name(db_providers, credential_type, f"datasource provider {provider_id}")
def add_datasource_oauth_provider(
self,
name: str | None,
tenant_id: str,
provider_id: DatasourceProviderID,
credentials: dict,
) -> None:
"""
add datasource oauth provider
"""
credential_type = CredentialType.OAUTH2
with Session(db.engine) as session:
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{credential_type.value}"
with redis_client.lock(lock, timeout=20):
db_provider_name = name or self.generate_next_datasource_provider_name(
session=session,
tenant_id=tenant_id,
provider_id=provider_id,
credential_type=credential_type,
)
if session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, name=db_provider_name).count() > 0:
raise ValueError("name is already exists")
provider_credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id, provider_id=f"{provider_id}"
)
for key, value in credentials.items():
if key in provider_credential_secret_variables:
# if send [__HIDDEN__] in secret input, it will be same as original value
credentials[key] = encrypter.encrypt_token(tenant_id, value)
datasource_provider = DatasourceProvider(
tenant_id=tenant_id,
name=db_provider_name,
provider=provider_id.provider_name,
plugin_id=provider_id.plugin_id,
auth_type=credential_type.value,
encrypted_credentials=credentials,
)
session.add(datasource_provider)
session.commit()
def add_datasource_api_key_provider(
self,
name: str | None,
tenant_id: str,
provider_id: DatasourceProviderID,
credentials: dict,
) -> None:
"""
validate datasource provider credentials.
@ -31,45 +99,49 @@ class DatasourceProviderService:
:param provider:
:param credentials:
"""
# check name is exist
datasource_provider = db.session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, name=name).first()
if datasource_provider:
raise ValueError("Authorization name is already exists")
provider_name = provider_id.provider_name
plugin_id = provider_id.plugin_id
with Session(db.engine) as session:
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_api_key"
with redis_client.lock(lock, timeout=20):
db_provider_name = name or self.generate_next_datasource_provider_name(
session=session,
tenant_id=tenant_id,
provider_id=provider_id,
credential_type=CredentialType.API_KEY,
)
credential_valid = self.provider_manager.validate_provider_credentials(
tenant_id=tenant_id,
user_id=current_user.id,
provider=provider,
plugin_id=plugin_id,
credentials=credentials,
)
if credential_valid:
# Get all provider configurations of the current workspace
datasource_provider = (
db.session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, plugin_id=plugin_id, provider=provider, auth_type="api_key")
.first()
)
# check name is exist
if session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, name=db_provider_name).count() > 0:
raise ValueError("Authorization name is already exists")
provider_credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}"
)
for key, value in credentials.items():
if key in provider_credential_secret_variables:
# if send [__HIDDEN__] in secret input, it will be same as original value
credentials[key] = encrypter.encrypt_token(tenant_id, value)
datasource_provider = DatasourceProvider(
tenant_id=tenant_id,
name=name,
provider=provider,
plugin_id=plugin_id,
auth_type="api_key",
encrypted_credentials=credentials,
)
db.session.add(datasource_provider)
db.session.commit()
else:
raise CredentialsValidateFailedError()
credential_valid = self.provider_manager.validate_provider_credentials(
tenant_id=tenant_id,
user_id=current_user.id,
provider=provider_name,
plugin_id=plugin_id,
credentials=credentials,
)
if credential_valid:
provider_credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id, provider_id=f"{provider_id}"
)
for key, value in credentials.items():
if key in provider_credential_secret_variables:
# if send [__HIDDEN__] in secret input, it will be same as original value
credentials[key] = encrypter.encrypt_token(tenant_id, value)
datasource_provider = DatasourceProvider(
tenant_id=tenant_id,
name=db_provider_name,
provider=provider_name,
plugin_id=plugin_id,
auth_type="api_key",
encrypted_credentials=credentials,
)
db.session.add(datasource_provider)
db.session.commit()
else:
raise CredentialsValidateFailedError()
def extract_secret_variables(self, tenant_id: str, provider_id: str) -> list[str]:
"""

View File

@ -1,6 +1,5 @@
import json
import logging
import re
from collections.abc import Mapping
from pathlib import Path
from typing import Any, Optional
@ -11,6 +10,7 @@ from configs import dify_config
from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.helper.position_helper import is_filtered
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
from core.helper.provider_name_generator import generate_provider_name
from core.plugin.entities.plugin import ToolProviderID
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
@ -299,42 +299,18 @@ class BuiltinToolManageService:
def generate_builtin_tool_provider_name(
session: Session, tenant_id: str, provider: str, credential_type: CredentialType
) -> str:
try:
db_providers = (
session.query(BuiltinToolProvider)
.filter_by(
tenant_id=tenant_id,
provider=provider,
credential_type=credential_type.value,
)
.order_by(BuiltinToolProvider.created_at.desc())
.all()
db_providers = (
session.query(BuiltinToolProvider)
.filter_by(
tenant_id=tenant_id,
provider=provider,
credential_type=credential_type.value,
)
# Get the default name pattern
default_pattern = f"{credential_type.get_name()}"
# Find all names that match the default pattern: "{default_pattern} {number}"
pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$"
numbers = []
for db_provider in db_providers:
if db_provider.name:
match = re.match(pattern, db_provider.name.strip())
if match:
numbers.append(int(match.group(1)))
# If no default pattern names found, start with 1
if not numbers:
return f"{default_pattern} 1"
# Find the next number
max_number = max(numbers)
return f"{default_pattern} {max_number + 1}"
except Exception as e:
logger.warning(f"Error generating next provider name for {provider}: {str(e)}")
# fallback
return f"{credential_type.get_name()} 1"
.order_by(BuiltinToolProvider.created_at.desc())
.all()
)
return generate_provider_name(db_providers, credential_type, f"builtin tool provider {provider}")
@staticmethod
def get_builtin_tool_provider_credentials(