This commit is contained in:
jyong 2025-06-04 15:12:05 +08:00
parent abcca11479
commit c09c8c6e5b
5 changed files with 75 additions and 36 deletions

View File

@ -24,7 +24,13 @@ class DatasourcePluginOauthApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider, plugin_id):
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
provider = args["provider"]
plugin_id = args["plugin_id"]
# Check user role first
if not current_user.is_editor:
raise Forbidden()
@ -35,7 +41,7 @@ class DatasourcePluginOauthApi(Resource):
if not plugin_oauth_config:
raise NotFound()
oauth_handler = OAuthHandler()
redirect_url = f"{dify_config.CONSOLE_WEB_URL}/oauth/datasource/provider/{provider}/plugin/{plugin_id}/callback"
redirect_url = f"{dify_config.CONSOLE_WEB_URL}/oauth/datasource/callback?provider={provider}&plugin_id={plugin_id}"
system_credentials = plugin_oauth_config.system_credentials
if system_credentials:
system_credentials["redirect_url"] = redirect_url
@ -49,7 +55,13 @@ class DatasourceOauthCallback(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider, plugin_id):
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
provider = args["provider"]
plugin_id = args["plugin_id"]
oauth_handler = OAuthHandler()
plugin_oauth_config = (
db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
@ -76,11 +88,13 @@ class DatasourceAuth(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self, provider, plugin_id):
def post(self):
if not current_user.is_editor:
raise Forbidden()
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="json")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
@ -89,8 +103,8 @@ class DatasourceAuth(Resource):
try:
datasource_provider_service.datasource_provider_credentials_validate(
tenant_id=current_user.current_tenant_id,
provider=provider,
plugin_id=plugin_id,
provider=args["provider"],
plugin_id=args["plugin_id"],
credentials=args["credentials"],
)
except CredentialsValidateFailedError as ex:
@ -101,10 +115,16 @@ class DatasourceAuth(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, provider, plugin_id):
def get(self):
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.get_datasource_credentials(
tenant_id=current_user.current_tenant_id, provider=provider, plugin_id=plugin_id
tenant_id=current_user.current_tenant_id,
provider=args["provider"],
plugin_id=args["plugin_id"]
)
return {"result": datasources}, 200
@ -113,12 +133,18 @@ class DatasourceAuthDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
def delete(self, provider, plugin_id):
def delete(self):
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
args = parser.parse_args()
if not current_user.is_editor:
raise Forbidden()
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_datasource_credentials(
tenant_id=current_user.current_tenant_id, provider=provider, plugin_id=plugin_id
tenant_id=current_user.current_tenant_id,
provider=args["provider"],
plugin_id=args["plugin_id"]
)
return {"result": "success"}, 200
@ -126,13 +152,13 @@ class DatasourceAuthDeleteApi(Resource):
# Import Rag Pipeline
api.add_resource(
DatasourcePluginOauthApi,
"/oauth/datasource/provider/<string:provider>/plugin/<string:plugin_id>",
"/oauth/plugin/datasource",
)
api.add_resource(
DatasourceOauthCallback,
"/oauth/datasource/provider/<string:provider>/plugin/<string:plugin_id>/callback",
"/oauth/plugin/datasource/callback",
)
api.add_resource(
DatasourceAuth,
"/auth/datasource/provider/<string:provider>/plugin/<string:plugin_id>",
"/auth/plugin/datasource",
)

View File

@ -24,7 +24,7 @@ class DatasourceManager:
@classmethod
def get_datasource_plugin_provider(
cls, provider: str, tenant_id: str, datasource_type: DatasourceProviderType
cls, provider_id: str, tenant_id: str, datasource_type: DatasourceProviderType
) -> DatasourcePluginProviderController:
"""
get the datasource plugin provider
@ -38,13 +38,13 @@ class DatasourceManager:
with contexts.datasource_plugin_providers_lock.get():
datasource_plugin_providers = contexts.datasource_plugin_providers.get()
if provider in datasource_plugin_providers:
return datasource_plugin_providers[provider]
if provider_id in datasource_plugin_providers:
return datasource_plugin_providers[provider_id]
manager = PluginDatasourceManager()
provider_entity = manager.fetch_datasource_provider(tenant_id, provider)
provider_entity = manager.fetch_datasource_provider(tenant_id, provider_id)
if not provider_entity:
raise DatasourceProviderNotFoundError(f"plugin provider {provider} not found")
raise DatasourceProviderNotFoundError(f"plugin provider {provider_id} not found")
match datasource_type:
case DatasourceProviderType.ONLINE_DOCUMENT:
@ -71,7 +71,7 @@ class DatasourceManager:
case _:
raise ValueError(f"Unsupported datasource type: {datasource_type}")
datasource_plugin_providers[provider] = controller
datasource_plugin_providers[provider_id] = controller
return controller

View File

@ -40,16 +40,25 @@ class PluginDatasourceManager(BasePluginClient):
)
local_file_datasource_provider = PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())
return [local_file_datasource_provider] + response
all_response = [local_file_datasource_provider] + response
def fetch_datasource_provider(self, tenant_id: str, provider: str) -> PluginDatasourceProviderEntity:
for provider in all_response:
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
# override the provider name for each tool to plugin_id/provider_name
for tool in provider.declaration.datasources:
tool.identity.provider = provider.declaration.identity.name
return all_response
def fetch_datasource_provider(self, tenant_id: str, provider_id: str) -> PluginDatasourceProviderEntity:
"""
Fetch datasource provider for the given tenant and plugin.
"""
if provider == "langgenius/file/file":
if provider_id == "langgenius/file/file":
return PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())
tool_provider_id = ToolProviderID(provider)
tool_provider_id = ToolProviderID(provider_id)
def transformer(json_response: dict[str, Any]) -> dict:
data = json_response.get("data")
@ -225,13 +234,13 @@ class PluginDatasourceManager(BasePluginClient):
def _get_local_file_datasource_provider(self) -> dict[str, Any]:
return {
"id": "langgenius/file/file",
"plugin_id": "langgenius/file/file",
"provider": "langgenius",
"plugin_id": "langgenius/file",
"provider": "file",
"plugin_unique_identifier": "langgenius/file:0.0.1@dify",
"declaration": {
"identity": {
"author": "langgenius",
"name": "langgenius/file/file",
"name": "file",
"label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
"icon": "https://cloud.dify.ai/console/api/workspaces/current/plugin/icon?tenant_id=945b4365-9d99-48c1-8c47-90593fe8b9c9&filename=13d9312f6b1352d3939b90a5257de58ff3cd619d5be4f5b266ff0298935ac328.svg",
"description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
@ -243,7 +252,7 @@ class PluginDatasourceManager(BasePluginClient):
"identity": {
"author": "langgenius",
"name": "upload-file",
"provider": "langgenius",
"provider": "file",
"label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
},
"parameters": [],

View File

@ -25,12 +25,12 @@ class DatasourceProvider(Base):
__tablename__ = "datasource_providers"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="datasource_provider_pkey"),
db.UniqueConstraint("plugin_id", "provider", name="datasource_provider_plugin_id_provider_idx"),
db.UniqueConstraint("plugin_id", "provider", "auth_type", name="datasource_provider_auth_type_provider_idx"),
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
plugin_id: Mapped[str] = db.Column(db.TEXT, nullable=False)
provider: Mapped[str] = db.Column(db.String(255), nullable=False)
plugin_id: Mapped[str] = db.Column(db.TEXT, nullable=False)
auth_type: Mapped[str] = db.Column(db.String(255), nullable=False)
encrypted_credentials: Mapped[dict] = db.Column(JSONB, nullable=False)
created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now)

View File

@ -38,11 +38,14 @@ class DatasourceProviderService:
# Get all provider configurations of the current workspace
datasource_provider = (
db.session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
.filter_by(tenant_id=tenant_id, plugin_id=plugin_id, auth_type="api_key")
.first()
)
provider_credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider=provider)
provider_credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id,
provider_id=f"{plugin_id}/{provider}"
)
if not datasource_provider:
for key, value in credentials.items():
if key in provider_credential_secret_variables:
@ -73,14 +76,16 @@ class DatasourceProviderService:
else:
raise CredentialsValidateFailedError()
def extract_secret_variables(self, tenant_id: str, provider: str) -> list[str]:
def extract_secret_variables(self, tenant_id: str, provider_id: str) -> list[str]:
"""
Extract secret input form variables.
:param credential_form_schemas:
:return:
"""
datasource_provider = self.provider_manager.fetch_datasource_provider(tenant_id=tenant_id, provider=provider)
datasource_provider = self.provider_manager.fetch_datasource_provider(tenant_id=tenant_id,
provider_id=provider_id
)
credential_form_schemas = datasource_provider.declaration.credentials_schema
secret_input_form_variables = []
for credential_form_schema in credential_form_schemas:
@ -94,8 +99,7 @@ class DatasourceProviderService:
get datasource credentials.
:param tenant_id: workspace id
:param provider: provider name
:param plugin_id: plugin id
:param provider_id: provider id
:return:
"""
# Get all provider configurations of the current workspace
@ -114,7 +118,7 @@ class DatasourceProviderService:
for datasource_provider in datasource_providers:
encrypted_credentials = datasource_provider.encrypted_credentials
# Get provider credential secret variables
credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider=provider)
credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider_id=provider)
# Obfuscate provider credentials
copy_credentials = encrypted_credentials.copy()