mirror of
https://github.com/langgenius/dify.git
synced 2025-12-24 08:32:16 +00:00
r2
This commit is contained in:
parent
abcca11479
commit
c09c8c6e5b
@ -24,7 +24,13 @@ class DatasourcePluginOauthApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider, plugin_id):
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
|
||||
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
|
||||
args = parser.parse_args()
|
||||
provider = args["provider"]
|
||||
plugin_id = args["plugin_id"]
|
||||
# Check user role first
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
@ -35,7 +41,7 @@ class DatasourcePluginOauthApi(Resource):
|
||||
if not plugin_oauth_config:
|
||||
raise NotFound()
|
||||
oauth_handler = OAuthHandler()
|
||||
redirect_url = f"{dify_config.CONSOLE_WEB_URL}/oauth/datasource/provider/{provider}/plugin/{plugin_id}/callback"
|
||||
redirect_url = f"{dify_config.CONSOLE_WEB_URL}/oauth/datasource/callback?provider={provider}&plugin_id={plugin_id}"
|
||||
system_credentials = plugin_oauth_config.system_credentials
|
||||
if system_credentials:
|
||||
system_credentials["redirect_url"] = redirect_url
|
||||
@ -49,7 +55,13 @@ class DatasourceOauthCallback(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider, plugin_id):
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
|
||||
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
|
||||
args = parser.parse_args()
|
||||
provider = args["provider"]
|
||||
plugin_id = args["plugin_id"]
|
||||
oauth_handler = OAuthHandler()
|
||||
plugin_oauth_config = (
|
||||
db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
|
||||
@ -76,11 +88,13 @@ class DatasourceAuth(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self, provider, plugin_id):
|
||||
def post(self):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="json")
|
||||
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -89,8 +103,8 @@ class DatasourceAuth(Resource):
|
||||
try:
|
||||
datasource_provider_service.datasource_provider_credentials_validate(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=provider,
|
||||
plugin_id=plugin_id,
|
||||
provider=args["provider"],
|
||||
plugin_id=args["plugin_id"],
|
||||
credentials=args["credentials"],
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
@ -101,10 +115,16 @@ class DatasourceAuth(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, provider, plugin_id):
|
||||
def get(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
|
||||
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
|
||||
args = parser.parse_args()
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasources = datasource_provider_service.get_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id, provider=provider, plugin_id=plugin_id
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=args["provider"],
|
||||
plugin_id=args["plugin_id"]
|
||||
)
|
||||
return {"result": datasources}, 200
|
||||
|
||||
@ -113,12 +133,18 @@ class DatasourceAuthDeleteApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def delete(self, provider, plugin_id):
|
||||
def delete(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
|
||||
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
|
||||
args = parser.parse_args()
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
datasource_provider_service = DatasourceProviderService()
|
||||
datasource_provider_service.remove_datasource_credentials(
|
||||
tenant_id=current_user.current_tenant_id, provider=provider, plugin_id=plugin_id
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
provider=args["provider"],
|
||||
plugin_id=args["plugin_id"]
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@ -126,13 +152,13 @@ class DatasourceAuthDeleteApi(Resource):
|
||||
# Import Rag Pipeline
|
||||
api.add_resource(
|
||||
DatasourcePluginOauthApi,
|
||||
"/oauth/datasource/provider/<string:provider>/plugin/<string:plugin_id>",
|
||||
"/oauth/plugin/datasource",
|
||||
)
|
||||
api.add_resource(
|
||||
DatasourceOauthCallback,
|
||||
"/oauth/datasource/provider/<string:provider>/plugin/<string:plugin_id>/callback",
|
||||
"/oauth/plugin/datasource/callback",
|
||||
)
|
||||
api.add_resource(
|
||||
DatasourceAuth,
|
||||
"/auth/datasource/provider/<string:provider>/plugin/<string:plugin_id>",
|
||||
"/auth/plugin/datasource",
|
||||
)
|
||||
|
||||
@ -24,7 +24,7 @@ class DatasourceManager:
|
||||
|
||||
@classmethod
|
||||
def get_datasource_plugin_provider(
|
||||
cls, provider: str, tenant_id: str, datasource_type: DatasourceProviderType
|
||||
cls, provider_id: str, tenant_id: str, datasource_type: DatasourceProviderType
|
||||
) -> DatasourcePluginProviderController:
|
||||
"""
|
||||
get the datasource plugin provider
|
||||
@ -38,13 +38,13 @@ class DatasourceManager:
|
||||
|
||||
with contexts.datasource_plugin_providers_lock.get():
|
||||
datasource_plugin_providers = contexts.datasource_plugin_providers.get()
|
||||
if provider in datasource_plugin_providers:
|
||||
return datasource_plugin_providers[provider]
|
||||
if provider_id in datasource_plugin_providers:
|
||||
return datasource_plugin_providers[provider_id]
|
||||
|
||||
manager = PluginDatasourceManager()
|
||||
provider_entity = manager.fetch_datasource_provider(tenant_id, provider)
|
||||
provider_entity = manager.fetch_datasource_provider(tenant_id, provider_id)
|
||||
if not provider_entity:
|
||||
raise DatasourceProviderNotFoundError(f"plugin provider {provider} not found")
|
||||
raise DatasourceProviderNotFoundError(f"plugin provider {provider_id} not found")
|
||||
|
||||
match datasource_type:
|
||||
case DatasourceProviderType.ONLINE_DOCUMENT:
|
||||
@ -71,7 +71,7 @@ class DatasourceManager:
|
||||
case _:
|
||||
raise ValueError(f"Unsupported datasource type: {datasource_type}")
|
||||
|
||||
datasource_plugin_providers[provider] = controller
|
||||
datasource_plugin_providers[provider_id] = controller
|
||||
|
||||
return controller
|
||||
|
||||
|
||||
@ -40,16 +40,25 @@ class PluginDatasourceManager(BasePluginClient):
|
||||
)
|
||||
local_file_datasource_provider = PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())
|
||||
|
||||
return [local_file_datasource_provider] + response
|
||||
all_response = [local_file_datasource_provider] + response
|
||||
|
||||
def fetch_datasource_provider(self, tenant_id: str, provider: str) -> PluginDatasourceProviderEntity:
|
||||
for provider in all_response:
|
||||
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
|
||||
|
||||
# override the provider name for each tool to plugin_id/provider_name
|
||||
for tool in provider.declaration.datasources:
|
||||
tool.identity.provider = provider.declaration.identity.name
|
||||
|
||||
return all_response
|
||||
|
||||
def fetch_datasource_provider(self, tenant_id: str, provider_id: str) -> PluginDatasourceProviderEntity:
|
||||
"""
|
||||
Fetch datasource provider for the given tenant and plugin.
|
||||
"""
|
||||
if provider == "langgenius/file/file":
|
||||
if provider_id == "langgenius/file/file":
|
||||
return PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())
|
||||
|
||||
tool_provider_id = ToolProviderID(provider)
|
||||
tool_provider_id = ToolProviderID(provider_id)
|
||||
|
||||
def transformer(json_response: dict[str, Any]) -> dict:
|
||||
data = json_response.get("data")
|
||||
@ -225,13 +234,13 @@ class PluginDatasourceManager(BasePluginClient):
|
||||
def _get_local_file_datasource_provider(self) -> dict[str, Any]:
|
||||
return {
|
||||
"id": "langgenius/file/file",
|
||||
"plugin_id": "langgenius/file/file",
|
||||
"provider": "langgenius",
|
||||
"plugin_id": "langgenius/file",
|
||||
"provider": "file",
|
||||
"plugin_unique_identifier": "langgenius/file:0.0.1@dify",
|
||||
"declaration": {
|
||||
"identity": {
|
||||
"author": "langgenius",
|
||||
"name": "langgenius/file/file",
|
||||
"name": "file",
|
||||
"label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
|
||||
"icon": "https://cloud.dify.ai/console/api/workspaces/current/plugin/icon?tenant_id=945b4365-9d99-48c1-8c47-90593fe8b9c9&filename=13d9312f6b1352d3939b90a5257de58ff3cd619d5be4f5b266ff0298935ac328.svg",
|
||||
"description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
|
||||
@ -243,7 +252,7 @@ class PluginDatasourceManager(BasePluginClient):
|
||||
"identity": {
|
||||
"author": "langgenius",
|
||||
"name": "upload-file",
|
||||
"provider": "langgenius",
|
||||
"provider": "file",
|
||||
"label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
|
||||
},
|
||||
"parameters": [],
|
||||
|
||||
@ -25,12 +25,12 @@ class DatasourceProvider(Base):
|
||||
__tablename__ = "datasource_providers"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="datasource_provider_pkey"),
|
||||
db.UniqueConstraint("plugin_id", "provider", name="datasource_provider_plugin_id_provider_idx"),
|
||||
db.UniqueConstraint("plugin_id", "provider", "auth_type", name="datasource_provider_auth_type_provider_idx"),
|
||||
)
|
||||
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
|
||||
tenant_id = db.Column(StringUUID, nullable=False)
|
||||
plugin_id: Mapped[str] = db.Column(db.TEXT, nullable=False)
|
||||
provider: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
plugin_id: Mapped[str] = db.Column(db.TEXT, nullable=False)
|
||||
auth_type: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
encrypted_credentials: Mapped[dict] = db.Column(JSONB, nullable=False)
|
||||
created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, default=datetime.now)
|
||||
|
||||
@ -38,11 +38,14 @@ class DatasourceProviderService:
|
||||
# Get all provider configurations of the current workspace
|
||||
datasource_provider = (
|
||||
db.session.query(DatasourceProvider)
|
||||
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
|
||||
.filter_by(tenant_id=tenant_id, plugin_id=plugin_id, auth_type="api_key")
|
||||
.first()
|
||||
)
|
||||
|
||||
provider_credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider=provider)
|
||||
provider_credential_secret_variables = self.extract_secret_variables(
|
||||
tenant_id=tenant_id,
|
||||
provider_id=f"{plugin_id}/{provider}"
|
||||
)
|
||||
if not datasource_provider:
|
||||
for key, value in credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
@ -73,14 +76,16 @@ class DatasourceProviderService:
|
||||
else:
|
||||
raise CredentialsValidateFailedError()
|
||||
|
||||
def extract_secret_variables(self, tenant_id: str, provider: str) -> list[str]:
|
||||
def extract_secret_variables(self, tenant_id: str, provider_id: str) -> list[str]:
|
||||
"""
|
||||
Extract secret input form variables.
|
||||
|
||||
:param credential_form_schemas:
|
||||
:return:
|
||||
"""
|
||||
datasource_provider = self.provider_manager.fetch_datasource_provider(tenant_id=tenant_id, provider=provider)
|
||||
datasource_provider = self.provider_manager.fetch_datasource_provider(tenant_id=tenant_id,
|
||||
provider_id=provider_id
|
||||
)
|
||||
credential_form_schemas = datasource_provider.declaration.credentials_schema
|
||||
secret_input_form_variables = []
|
||||
for credential_form_schema in credential_form_schemas:
|
||||
@ -94,8 +99,7 @@ class DatasourceProviderService:
|
||||
get datasource credentials.
|
||||
|
||||
:param tenant_id: workspace id
|
||||
:param provider: provider name
|
||||
:param plugin_id: plugin id
|
||||
:param provider_id: provider id
|
||||
:return:
|
||||
"""
|
||||
# Get all provider configurations of the current workspace
|
||||
@ -114,7 +118,7 @@ class DatasourceProviderService:
|
||||
for datasource_provider in datasource_providers:
|
||||
encrypted_credentials = datasource_provider.encrypted_credentials
|
||||
# Get provider credential secret variables
|
||||
credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider=provider)
|
||||
credential_secret_variables = self.extract_secret_variables(tenant_id=tenant_id, provider_id=provider)
|
||||
|
||||
# Obfuscate provider credentials
|
||||
copy_credentials = encrypted_credentials.copy()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user