This commit is contained in:
jyong 2025-06-04 17:29:39 +08:00
parent a82ab1d152
commit 8a147a00e8
3 changed files with 91 additions and 37 deletions

View File

@ -122,18 +122,18 @@ class DatasourceAuth(Resource):
args = parser.parse_args()
datasource_provider_service = DatasourceProviderService()
datasources = datasource_provider_service.get_datasource_credentials(
tenant_id=current_user.current_tenant_id,
provider=args["provider"],
tenant_id=current_user.current_tenant_id,
provider=args["provider"],
plugin_id=args["plugin_id"]
)
return {"result": datasources}, 200
class DatasourceAuthDeleteApi(Resource):
class DatasourceAuthUpdateDeleteApi(Resource):
@setup_required
@login_required
@account_initialization_required
def delete(self):
def delete(self, auth_id: str):
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
@ -142,12 +142,38 @@ class DatasourceAuthDeleteApi(Resource):
raise Forbidden()
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.remove_datasource_credentials(
tenant_id=current_user.current_tenant_id,
provider=args["provider"],
tenant_id=current_user.current_tenant_id,
auth_id=auth_id,
provider=args["provider"],
plugin_id=args["plugin_id"]
)
return {"result": "success"}, 200
@setup_required
@login_required
@account_initialization_required
def patch(self, auth_id: str):
parser = reqparse.RequestParser()
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
args = parser.parse_args()
if not current_user.is_editor:
raise Forbidden()
try:
datasource_provider_service = DatasourceProviderService()
datasource_provider_service.update_datasource_credentials(
tenant_id=current_user.current_tenant_id,
auth_id=auth_id,
provider=args["provider"],
plugin_id=args["plugin_id"],
credentials=args["credentials"],
)
except CredentialsValidateFailedError as ex:
raise ValueError(str(ex))
return {"result": "success"}, 201
# Import Rag Pipeline
api.add_resource(
@ -162,3 +188,8 @@ api.add_resource(
DatasourceAuth,
"/auth/plugin/datasource",
)
api.add_resource(
DatasourceAuth,
"/auth/plugin/datasource/<string:auth_id>",
)

View File

@ -25,7 +25,7 @@ class DatasourceProvider(Base):
__tablename__ = "datasource_providers"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="datasource_provider_pkey"),
db.UniqueConstraint("plugin_id", "provider", "auth_type", name="datasource_provider_auth_type_provider_idx"),
db.UniqueConstraint("plugin_id", "provider", name="datasource_provider_auth_type_provider_idx"),
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)

View File

@ -38,7 +38,7 @@ class DatasourceProviderService:
# Get all provider configurations of the current workspace
datasource_provider = (
db.session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, plugin_id=plugin_id, auth_type="api_key")
.filter_by(tenant_id=tenant_id, plugin_id=plugin_id, provider=provider, auth_type="api_key")
.first()
)
@ -46,33 +46,19 @@ class DatasourceProviderService:
tenant_id=tenant_id,
provider_id=f"{plugin_id}/{provider}"
)
if not datasource_provider:
for key, value in credentials.items():
if key in provider_credential_secret_variables:
# if send [__HIDDEN__] in secret input, it will be same as original value
credentials[key] = encrypter.encrypt_token(tenant_id, value)
datasource_provider = DatasourceProvider(
tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id,
auth_type="api_key",
encrypted_credentials=credentials,
)
db.session.add(datasource_provider)
db.session.commit()
else:
original_credentials = datasource_provider.encrypted_credentials
for key, value in credentials.items():
if key in provider_credential_secret_variables:
# if send [__HIDDEN__] in secret input, it will be same as original value
if value == HIDDEN_VALUE and key in original_credentials:
original_value = encrypter.encrypt_token(tenant_id, original_credentials[key])
credentials[key] = encrypter.encrypt_token(tenant_id, original_value)
else:
credentials[key] = encrypter.encrypt_token(tenant_id, value)
datasource_provider.encrypted_credentials = credentials
db.session.commit()
for key, value in credentials.items():
if key in provider_credential_secret_variables:
# if send [__HIDDEN__] in secret input, it will be same as original value
credentials[key] = encrypter.encrypt_token(tenant_id, value)
datasource_provider = DatasourceProvider(
tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id,
auth_type="api_key",
encrypted_credentials=credentials,
)
db.session.add(datasource_provider)
db.session.commit()
else:
raise CredentialsValidateFailedError()
@ -133,8 +119,45 @@ class DatasourceProviderService:
)
return copy_credentials_list
def update_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str, credentials: dict) -> None:
"""
update datasource credentials.
"""
credential_valid = self.provider_manager.validate_provider_credentials(
tenant_id=tenant_id, user_id=current_user.id, provider=provider, credentials=credentials
)
if credential_valid:
# Get all provider configurations of the current workspace
datasource_provider = (
db.session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
.first()
)
def remove_datasource_credentials(self, tenant_id: str, provider: str, plugin_id: str) -> None:
provider_credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id,
provider_id=f"{plugin_id}/{provider}"
)
if not datasource_provider:
raise ValueError("Datasource provider not found")
else:
original_credentials = datasource_provider.encrypted_credentials
for key, value in credentials.items():
if key in provider_credential_secret_variables:
# if send [__HIDDEN__] in secret input, it will be same as original value
if value == HIDDEN_VALUE and key in original_credentials:
original_value = encrypter.encrypt_token(tenant_id, original_credentials[key])
credentials[key] = encrypter.encrypt_token(tenant_id, original_value)
else:
credentials[key] = encrypter.encrypt_token(tenant_id, value)
datasource_provider.encrypted_credentials = credentials
db.session.commit()
else:
raise CredentialsValidateFailedError()
def remove_datasource_credentials(self, tenant_id: str, auth_id: str, provider: str, plugin_id: str) -> None:
"""
remove datasource credentials.
@ -145,7 +168,7 @@ class DatasourceProviderService:
"""
datasource_provider = (
db.session.query(DatasourceProvider)
.filter_by(tenant_id=tenant_id, provider=provider, plugin_id=plugin_id)
.filter_by(tenant_id=tenant_id, id=auth_id, provider=provider, plugin_id=plugin_id)
.first()
)
if datasource_provider: