From 657e813c7ff5ebd90ca23fc06af89c348a19cf7a Mon Sep 17 00:00:00 2001
From: jyong <718720800@qq.com>
Date: Mon, 28 Jul 2025 19:29:36 +0800
Subject: [PATCH] add old auth transform

---
 api/commands.py                               |  8 ++---
 .../console/datasets/data_source.py           |  4 +--
 .../rag/extractor/entity/extract_setting.py   |  1 +
 .../firecrawl/firecrawl_web_extractor.py      | 16 ++++++++--
 .../rag/extractor/jina_reader_extractor.py    | 16 ++++++++--
 api/core/rag/extractor/notion_extractor.py    |  2 --
 .../rag/extractor/watercrawl/extractor.py     | 16 ++++++++--
 api/services/datasource_provider_service.py   |  8 ++---
 api/services/website_service.py               | 29 ++++++++++++-------
 9 files changed, 68 insertions(+), 32 deletions(-)

diff --git a/api/commands.py b/api/commands.py
index 85856ca2a8..1eacdb542f 100644
--- a/api/commands.py
+++ b/api/commands.py
@@ -14,7 +14,6 @@ from configs import dify_config
 from constants.languages import languages
 from core.helper import encrypter
 from core.plugin.entities.plugin import DatasourceProviderID, PluginInstallationSource, ToolProviderID
-from core.plugin.impl.datasource import PluginDatasourceManager
 from core.plugin.impl.plugin import PluginInstaller
 from core.rag.datasource.vdb.vector_factory import Vector
 from core.rag.datasource.vdb.vector_type import VectorType
@@ -38,7 +37,6 @@ from models.provider import Provider, ProviderModel
 from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
 from models.tools import ToolOAuthSystemClient
 from services.account_service import AccountService, RegisterService, TenantService
-from services.auth import firecrawl
 from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs
 from services.plugin.data_migration import PluginDataMigration
 from services.plugin.plugin_migration import PluginMigration
@@ -1255,6 +1253,7 @@ def setup_datasource_oauth_client(provider, client_params):
     click.echo(click.style(f"params: {json.dumps(client_params_dict, indent=2, ensure_ascii=False)}", fg="green"))
     click.echo(click.style(f"Datasource oauth client setup successfully. id: {oauth_client.id}", fg="green"))
 
+
 @click.command("transform-datasource-credentials", help="Transform datasource credentials.")
 def transform_datasource_credentials():
     """
@@ -1273,7 +1272,6 @@ def transform_datasource_credentials():
     oauth_credential_type = CredentialType.OAUTH2
     api_key_credential_type = CredentialType.API_KEY
 
-
     # deal notion credentials
     deal_notion_count = 0
     notion_credentials = db.session.query(DataSourceOauthBinding).filter_by(provider="notion").all()
@@ -1429,5 +1427,7 @@ def transform_datasource_credentials():
                 click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
                 return
     click.echo(click.style(f"Transforming notion successfully. deal_notion_count: {deal_notion_count}", fg="green"))
-    click.echo(click.style(f"Transforming firecrawl successfully. deal_firecrawl_count: {deal_firecrawl_count}", fg="green"))
+    click.echo(
+        click.style(f"Transforming firecrawl successfully. deal_firecrawl_count: {deal_firecrawl_count}", fg="green")
+    )
     click.echo(click.style(f"Transforming jina successfully. deal_jina_count: {deal_jina_count}", fg="green"))
diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py
index ecf5d7d336..cfe9e7966f 100644
--- a/api/controllers/console/datasets/data_source.py
+++ b/api/controllers/console/datasets/data_source.py
@@ -1,5 +1,6 @@
 import json
-from typing import Generator, cast
+from collections.abc import Generator
+from typing import cast
 
 from flask import request
 from flask_login import current_user
@@ -20,7 +21,6 @@ from fields.data_source_fields import integrate_list_fields, integrate_notion_in
 from libs.datetime_utils import naive_utc_now
 from libs.login import login_required
 from models import DataSourceOauthBinding, Document
-from models.oauth import DatasourceProvider
 from services.dataset_service import DatasetService, DocumentService
 from services.datasource_provider_service import DatasourceProviderService
 from tasks.document_indexing_sync_task import document_indexing_sync_task
diff --git a/api/core/rag/extractor/entity/extract_setting.py b/api/core/rag/extractor/entity/extract_setting.py
index 3e38c9153c..d0a4a9353f 100644
--- a/api/core/rag/extractor/entity/extract_setting.py
+++ b/api/core/rag/extractor/entity/extract_setting.py
@@ -10,6 +10,7 @@ class NotionInfo(BaseModel):
     """
     Notion import info.
     """
+
     credential_id: Optional[str] = None
     notion_workspace_id: str
     notion_obj_id: str
diff --git a/api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py b/api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py
index cf5ede4daa..25454fa1d9 100644
--- a/api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py
+++ b/api/core/rag/extractor/firecrawl/firecrawl_web_extractor.py
@@ -1,4 +1,5 @@
 from typing import Optional
+
 from core.rag.extractor.extractor_base import BaseExtractor
 from core.rag.models.document import Document
 from services.website_service import WebsiteService
@@ -16,8 +17,15 @@ class FirecrawlWebExtractor(BaseExtractor):
         only_main_content: Only return the main content of the page excluding headers, navs, footers, etc.
     """
 
-    def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = True,
-                 credential_id: Optional[str] = None):
+    def __init__(
+        self,
+        url: str,
+        job_id: str,
+        tenant_id: str,
+        mode: str = "crawl",
+        only_main_content: bool = True,
+        credential_id: Optional[str] = None,
+    ):
         """Initialize with url, api_key, base_url and mode."""
         self._url = url
         self.job_id = job_id
@@ -30,7 +38,9 @@ class FirecrawlWebExtractor(BaseExtractor):
         """Extract content from the URL."""
         documents = []
         if self.mode == "crawl":
-            crawl_data = WebsiteService.get_crawl_url_data(self.job_id, "firecrawl", self._url, self.tenant_id, self.credential_id)
+            crawl_data = WebsiteService.get_crawl_url_data(
+                self.job_id, "firecrawl", self._url, self.tenant_id, self.credential_id
+            )
             if crawl_data is None:
                 return []
             document = Document(
diff --git a/api/core/rag/extractor/jina_reader_extractor.py b/api/core/rag/extractor/jina_reader_extractor.py
index a74fb203e2..88c240393f 100644
--- a/api/core/rag/extractor/jina_reader_extractor.py
+++ b/api/core/rag/extractor/jina_reader_extractor.py
@@ -1,4 +1,5 @@
 from typing import Optional
+
 from core.rag.extractor.extractor_base import BaseExtractor
 from core.rag.models.document import Document
 from services.website_service import WebsiteService
@@ -9,8 +10,15 @@ class JinaReaderWebExtractor(BaseExtractor):
     Crawl and scrape websites and return content in clean llm-ready markdown.
     """
 
-    def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = False,
-                 credential_id: Optional[str] = None):
+    def __init__(
+        self,
+        url: str,
+        job_id: str,
+        tenant_id: str,
+        mode: str = "crawl",
+        only_main_content: bool = False,
+        credential_id: Optional[str] = None,
+    ):
         """Initialize with url, api_key, base_url and mode."""
         self._url = url
         self.job_id = job_id
@@ -23,7 +31,9 @@ class JinaReaderWebExtractor(BaseExtractor):
         """Extract content from the URL."""
         documents = []
         if self.mode == "crawl":
-            crawl_data = WebsiteService.get_crawl_url_data(self.job_id, "jinareader", self._url, self.tenant_id, self.credential_id)
+            crawl_data = WebsiteService.get_crawl_url_data(
+                self.job_id, "jinareader", self._url, self.tenant_id, self.credential_id
+            )
             if crawl_data is None:
                 return []
             document = Document(
diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py
index 73bf7c81fb..b84e7a8c6a 100644
--- a/api/core/rag/extractor/notion_extractor.py
+++ b/api/core/rag/extractor/notion_extractor.py
@@ -9,8 +9,6 @@ from core.rag.extractor.extractor_base import BaseExtractor
 from core.rag.models.document import Document
 from extensions.ext_database import db
 from models.dataset import Document as DocumentModel
-from models.oauth import DatasourceProvider
-from models.source import DataSourceOauthBinding
 from services.datasource_provider_service import DatasourceProviderService
 
 logger = logging.getLogger(__name__)
diff --git a/api/core/rag/extractor/watercrawl/extractor.py b/api/core/rag/extractor/watercrawl/extractor.py
index 5559917cc5..e5805d1b64 100644
--- a/api/core/rag/extractor/watercrawl/extractor.py
+++ b/api/core/rag/extractor/watercrawl/extractor.py
@@ -1,4 +1,5 @@
 from typing import Optional
+
 from core.rag.extractor.extractor_base import BaseExtractor
 from core.rag.models.document import Document
 from services.website_service import WebsiteService
@@ -17,8 +18,15 @@ class WaterCrawlWebExtractor(BaseExtractor):
     only_main_content: Only return the main content of the page excluding headers, navs, footers, etc.
     """
 
-    def __init__(self, url: str, job_id: str, tenant_id: str, mode: str = "crawl", only_main_content: bool = True,
-                 credential_id: Optional[str] = None):
+    def __init__(
+        self,
+        url: str,
+        job_id: str,
+        tenant_id: str,
+        mode: str = "crawl",
+        only_main_content: bool = True,
+        credential_id: Optional[str] = None,
+    ):
         """Initialize with url, api_key, base_url and mode."""
         self._url = url
         self.job_id = job_id
@@ -31,7 +39,9 @@ class WaterCrawlWebExtractor(BaseExtractor):
         """Extract content from the URL."""
        documents = []
         if self.mode == "crawl":
-            crawl_data = WebsiteService.get_crawl_url_data(self.job_id, "watercrawl", self._url, self.tenant_id, self.credential_id)
+            crawl_data = WebsiteService.get_crawl_url_data(
+                self.job_id, "watercrawl", self._url, self.tenant_id, self.credential_id
+            )
             if crawl_data is None:
                 return []
             document = Document(
diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py
index 3f85310ab2..a3600db4e5 100644
--- a/api/services/datasource_provider_service.py
+++ b/api/services/datasource_provider_service.py
@@ -56,15 +56,15 @@ class DatasourceProviderService:
             return {}
         return datasource_provider.encrypted_credentials
 
-    def get_real_credential_by_id(self, tenant_id: str, credential_id: str, provider: str, plugin_id: str) -> dict[str, Any]:
+    def get_real_credential_by_id(
+        self, tenant_id: str, credential_id: str, provider: str, plugin_id: str
+    ) -> dict[str, Any]:
         """
         get credential by id
         """
         with Session(db.engine) as session:
             datasource_provider = (
-                session.query(DatasourceProvider)
-                .filter_by(tenant_id=tenant_id, id=credential_id)
-                .first()
+                session.query(DatasourceProvider).filter_by(tenant_id=tenant_id, id=credential_id).first()
             )
             if not datasource_provider:
                 return {}
diff --git a/api/services/website_service.py b/api/services/website_service.py
index a0bee311e7..a12128e248 100644
--- a/api/services/website_service.py
+++ b/api/services/website_service.py
@@ -120,7 +120,9 @@ class WebsiteService:
     """Service class for website crawling operations using different providers."""
 
     @classmethod
-    def _get_credentials_and_config(cls, tenant_id: str, provider: str, credential_id: Optional[str] = None) -> tuple[Any, Any]:
+    def _get_credentials_and_config(
+        cls, tenant_id: str, provider: str, credential_id: Optional[str] = None
+    ) -> tuple[Any, Any]:
         """Get and validate credentials for a provider."""
         if credential_id:
             if provider == "firecrawl":
@@ -164,7 +166,9 @@ class WebsiteService:
         """Crawl a URL using the specified provider with typed request."""
         request = api_request.to_crawl_request()
 
-        _, config = cls._get_credentials_and_config(current_user.current_tenant_id, request.provider, api_request.credential_id)
+        _, config = cls._get_credentials_and_config(
+            current_user.current_tenant_id, request.provider, api_request.credential_id
+        )
         if api_request.credential_id:
             api_key = _
         else:
@@ -258,9 +262,9 @@ class WebsiteService:
     @classmethod
     def get_crawl_status_typed(cls, api_request: WebsiteCrawlStatusApiRequest) -> dict[str, Any]:
         """Get crawl status using typed request."""
-        _, config = cls._get_credentials_and_config(current_user.current_tenant_id,
-                                                    api_request.provider,
-                                                    api_request.credential_id)
+        _, config = cls._get_credentials_and_config(
+            current_user.current_tenant_id, api_request.provider, api_request.credential_id
+        )
         if api_request.credential_id:
             api_key = _
         else:
@@ -337,7 +341,9 @@ class WebsiteService:
         return crawl_status_data
 
     @classmethod
-    def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str, credential_id: Optional[str] = None) -> dict[str, Any] | None:
+    def get_crawl_url_data(
+        cls, job_id: str, provider: str, url: str, tenant_id: str, credential_id: Optional[str] = None
+    ) -> dict[str, Any] | None:
         _, config = cls._get_credentials_and_config(tenant_id, provider, credential_id)
         if credential_id:
             api_key = _
@@ -412,13 +418,14 @@ class WebsiteService:
         return None
 
     @classmethod
-    def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool,
-                            credential_id: Optional[str] = None) -> dict[str, Any]:
+    def get_scrape_url_data(
+        cls, provider: str, url: str, tenant_id: str, only_main_content: bool, credential_id: Optional[str] = None
+    ) -> dict[str, Any]:
         request = ScrapeRequest(provider=provider, url=url, tenant_id=tenant_id, only_main_content=only_main_content)
 
-        _, config = cls._get_credentials_and_config(tenant_id=request.tenant_id,
-                                                    provider=request.provider,
-                                                    credential_id=credential_id)
+        _, config = cls._get_credentials_and_config(
+            tenant_id=request.tenant_id, provider=request.provider, credential_id=credential_id
+        )
         if credential_id:
             api_key = _
         else:
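
Reviewer note (editorial addition, not part of the commit): the extractor hunks
above all thread an optional credential_id through to
WebsiteService.get_crawl_url_data; when it is None, the service apparently
falls back to the tenant-level credentials (the fallback branch is elided in
this excerpt). A minimal usage sketch of the new constructor signature, where
the url, job_id, tenant_id, and credential_id values are placeholders rather
than real identifiers:

    from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor

    # With an explicit credential_id, get_crawl_url_data resolves that stored
    # credential; with None it uses the tenant's default configuration.
    extractor = FirecrawlWebExtractor(
        url="https://example.com",      # placeholder URL
        job_id="job-id",                # placeholder crawl job id
        tenant_id="tenant-id",          # placeholder tenant id
        mode="crawl",
        only_main_content=True,
        credential_id="credential-id",  # optional; omit to use the tenant default
    )
    documents = extractor.extract()     # returns [] when the crawl job has no data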
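Reviewer note (editorial addition, not part of the commit): the new
transform-datasource-credentials command walks the legacy
DataSourceOauthBinding rows per provider, converts each one, and prints a
per-provider tally. The per-binding conversion logic falls between the hunks
shown here, so the sketch below only restates the visible loop shape;
convert_binding is a hypothetical stand-in for the elided body.

    import click

    from extensions.ext_database import db
    from models.source import DataSourceOauthBinding

    # Condensed, hypothetical restatement of transform_datasource_credentials
    # for the "notion" provider; the other providers follow the same shape.
    oauth_credential_type = CredentialType.OAUTH2  # CredentialType import not shown in this excerpt
    deal_notion_count = 0
    notion_credentials = db.session.query(DataSourceOauthBinding).filter_by(provider="notion").all()
    for binding in notion_credentials:
        convert_binding(binding, credential_type=oauth_credential_type)  # hypothetical helper
        deal_notion_count += 1
    click.echo(click.style(f"Transforming notion successfully. deal_notion_count: {deal_notion_count}", fg="green"))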