diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index c78b36c3b9..bc91343c71 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -122,18 +122,18 @@ class DatasourceAuth(Resource): args = parser.parse_args() datasource_provider_service = DatasourceProviderService() datasources = datasource_provider_service.get_datasource_credentials( - tenant_id=current_user.current_tenant_id, - provider=args["provider"], + tenant_id=current_user.current_tenant_id, + provider=args["provider"], plugin_id=args["plugin_id"] ) return {"result": datasources}, 200 -class DatasourceAuthDeleteApi(Resource): +class DatasourceAuthUpdateDeleteApi(Resource): @setup_required @login_required @account_initialization_required - def delete(self): + def delete(self, auth_id: str): parser = reqparse.RequestParser() parser.add_argument("provider", type=str, required=True, nullable=False, location="args") parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args") @@ -142,12 +142,38 @@ class DatasourceAuthDeleteApi(Resource): raise Forbidden() datasource_provider_service = DatasourceProviderService() datasource_provider_service.remove_datasource_credentials( - tenant_id=current_user.current_tenant_id, - provider=args["provider"], + tenant_id=current_user.current_tenant_id, + auth_id=auth_id, + provider=args["provider"], plugin_id=args["plugin_id"] ) return {"result": "success"}, 200 + @setup_required + @login_required + @account_initialization_required + def patch(self, auth_id: str): + parser = reqparse.RequestParser() + parser.add_argument("provider", type=str, required=True, nullable=False, location="args") + parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args") + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + args = parser.parse_args() + if not current_user.is_editor: + raise Forbidden() + try: + datasource_provider_service = DatasourceProviderService() + datasource_provider_service.update_datasource_credentials( + tenant_id=current_user.current_tenant_id, + auth_id=auth_id, + provider=args["provider"], + plugin_id=args["plugin_id"], + credentials=args["credentials"], + ) + except CredentialsValidateFailedError as ex: + raise ValueError(str(ex)) + + return {"result": "success"}, 201 + # Import Rag Pipeline api.add_resource( @@ -162,3 +188,8 @@ api.add_resource( DatasourceAuth, "/auth/plugin/datasource", ) + +api.add_resource( + DatasourceAuth, + "/auth/plugin/datasource/", +) diff --git a/api/models/oauth.py b/api/models/oauth.py index d823bcae16..938a309069 100644 --- a/api/models/oauth.py +++ b/api/models/oauth.py @@ -25,7 +25,7 @@ class DatasourceProvider(Base): __tablename__ = "datasource_providers" __table_args__ = ( db.PrimaryKeyConstraint("id", name="datasource_provider_pkey"), - db.UniqueConstraint("plugin_id", "provider", "auth_type", name="datasource_provider_auth_type_provider_idx"), + db.UniqueConstraint("plugin_id", "provider", name="datasource_provider_auth_type_provider_idx"), ) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=False) diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 71edec760f..1344dfa9fe 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -38,7 +38,7 @@ class DatasourceProviderService: # Get all provider configurations of the current workspace datasource_provider = ( db.session.query(DatasourceProvider) - .filter_by(tenant_id=tenant_id, plugin_id=plugin_id, auth_type="api_key") + .filter_by(tenant_id=tenant_id, plugin_id=plugin_id, provider=provider, auth_type="api_key") .first() ) @@ -46,33 +46,19 @@ class DatasourceProviderService: tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}" ) - if not datasource_provider: - for key, value in credentials.items(): - if key in provider_credential_secret_variables: - # if send [__HIDDEN__] in secret input, it will be same as original value - credentials[key] = encrypter.encrypt_token(tenant_id, value) - datasource_provider = DatasourceProvider( - tenant_id=tenant_id, - provider=provider, - plugin_id=plugin_id, - auth_type="api_key", - encrypted_credentials=credentials, - ) - db.session.add(datasource_provider) - db.session.commit() - else: - original_credentials = datasource_provider.encrypted_credentials - for key, value in credentials.items(): - if key in provider_credential_secret_variables: - # if send [__HIDDEN__] in secret input, it will be same as original value - if value == HIDDEN_VALUE and key in original_credentials: - original_value = encrypter.encrypt_token(tenant_id, original_credentials[key]) - credentials[key] = encrypter.encrypt_token(tenant_id, original_value) - else: - credentials[key] = encrypter.encrypt_token(tenant_id, value) - - datasource_provider.encrypted_credentials = credentials - db.session.commit() + for key, value in credentials.items(): + if key in provider_credential_secret_variables: + # if send [__HIDDEN__] in secret input, it will be same as original value + credentials[key] = encrypter.encrypt_token(tenant_id, value) + datasource_provider = DatasourceProvider( + tenant_id=tenant_id, + provider=provider, + plugin_id=plugin_id, + auth_type="api_key", + encrypted_credentials=credentials, + ) + db.session.add(datasource_provider) + db.session.commit() else: raise CredentialsValidateFailedError() @@ -133,8 +119,45 @@ class DatasourceProviderService: ) return copy_credentials_list + + def update_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict) -> None: + """ + update datasource credentials. + """ + credential_valid = self.provider_manager.validate_provider_credentials( + tenant_id=tenant_id, user_id=current_user.id, provider=provider, credentials=credentials + ) + if credential_valid: + # Get all provider configurations of the current workspace + datasource_provider = ( + db.session.query(DatasourceProvider) + .filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id) + .first() + ) - def remove_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> None: + provider_credential_secret_variables = self.extract_secret_variables( + tenant_id=tenant_id, + provider_id=f"{plugin_id}/{provider}" + ) + if not datasource_provider: + raise ValueError("Datasource provider not found") + else: + original_credentials = datasource_provider.encrypted_credentials + for key, value in credentials.items(): + if key in provider_credential_secret_variables: + # if send [__HIDDEN__] in secret input, it will be same as original value + if value == HIDDEN_VALUE and key in original_credentials: + original_value = encrypter.encrypt_token(tenant_id, original_credentials[key]) + credentials[key] = encrypter.encrypt_token(tenant_id, original_value) + else: + credentials[key] = encrypter.encrypt_token(tenant_id, value) + + datasource_provider.encrypted_credentials = credentials + db.session.commit() + else: + raise CredentialsValidateFailedError() + + def remove_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str) -> None: """ remove datasource credentials. @@ -145,7 +168,7 @@ class DatasourceProviderService: """ datasource_provider = ( db.session.query(DatasourceProvider) - .filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id) + .filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id) .first() ) if datasource_provider: