dify/api/core/datasource/datasource_manager.py

101 lines
4.3 KiB
Python
Raw Normal View History

2025-04-25 14:56:22 +08:00
import logging
from threading import Lock
from typing import Union
import contexts
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.entities.common_entities import I18nObject
2025-05-23 15:55:41 +08:00
from core.datasource.entities.datasource_entities import DatasourceProviderType
2025-05-15 16:07:17 +08:00
from core.datasource.errors import DatasourceProviderNotFoundError
2025-05-23 15:55:41 +08:00
from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController
from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController
from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
from core.plugin.impl.datasource import PluginDatasourceManager
2025-04-25 14:56:22 +08:00
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
class DatasourceManager:
    """Resolve datasource plugin provider controllers and datasource runtimes.

    Provider controllers are cached in the context variable
    ``contexts.datasource_plugin_providers`` (keyed by provider name), and access
    to that cache is guarded by the lock stored in
    ``contexts.datasource_plugin_providers_lock``.
    """

    # NOTE(review): the four attributes below are not referenced by any method of
    # this class; they look like leftovers from a tool-manager template. They are
    # kept for backward compatibility in case external code touches them — confirm
    # and remove.
    _builtin_provider_lock = Lock()
    _hardcoded_providers: dict[str, DatasourcePluginProviderController] = {}
    _builtin_providers_loaded = False
    _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}

    # Supported datasource types and the controller class that wraps each one.
    # All of these controllers are constructed with the same keyword arguments
    # (entity, plugin_id, plugin_unique_identifier, tenant_id).
    _provider_controller_classes: dict[DatasourceProviderType, type[DatasourcePluginProviderController]] = {
        DatasourceProviderType.ONLINE_DOCUMENT: OnlineDocumentDatasourcePluginProviderController,
        DatasourceProviderType.WEBSITE_CRAWL: WebsiteCrawlDatasourcePluginProviderController,
        DatasourceProviderType.LOCAL_FILE: LocalFileDatasourcePluginProviderController,
    }

    @classmethod
    def get_datasource_plugin_provider(
        cls, provider: str, tenant_id: str, datasource_type: DatasourceProviderType
    ) -> DatasourcePluginProviderController:
        """
        Get the datasource plugin provider controller, creating and caching it on first use.

        :param provider: the provider name; also used as the cache key
        :param tenant_id: the tenant id passed to the plugin manager when fetching the provider
        :param datasource_type: selects which controller class wraps the fetched provider
        :return: the (possibly cached) provider controller
        :raises ValueError: if ``datasource_type`` is not supported (only checked on a cache miss)
        :raises DatasourceProviderNotFoundError: if the plugin manager returns no provider
        """
        # Lazily initialise the context-local cache and its lock the first time
        # this context asks for a provider.
        try:
            contexts.datasource_plugin_providers.get()
        except LookupError:
            contexts.datasource_plugin_providers.set({})
            contexts.datasource_plugin_providers_lock.set(Lock())

        with contexts.datasource_plugin_providers_lock.get():
            datasource_plugin_providers = contexts.datasource_plugin_providers.get()
            if provider in datasource_plugin_providers:
                return datasource_plugin_providers[provider]

            # Validate the requested type before making the remote fetch, so an
            # invalid argument does not cost a round-trip while the lock is held.
            controller_class = cls._provider_controller_classes.get(datasource_type)
            if controller_class is None:
                raise ValueError(f"Unsupported datasource type: {datasource_type}")

            manager = PluginDatasourceManager()
            provider_entity = manager.fetch_datasource_provider(tenant_id, provider)
            if not provider_entity:
                raise DatasourceProviderNotFoundError(f"plugin provider {provider} not found")

            controller = controller_class(
                entity=provider_entity.declaration,
                plugin_id=provider_entity.plugin_id,
                plugin_unique_identifier=provider_entity.plugin_unique_identifier,
                tenant_id=tenant_id,
            )
            datasource_plugin_providers[provider] = controller
            return controller

    @classmethod
    def get_datasource_runtime(
        cls,
        provider_id: str,
        datasource_name: str,
        tenant_id: str,
        datasource_type: DatasourceProviderType,
    ) -> DatasourcePlugin:
        """
        Get the datasource runtime.

        :param provider_id: the id of the provider
        :param datasource_name: the name of the datasource
        :param tenant_id: the tenant id
        :param datasource_type: the type of the datasource provider
        :return: the datasource plugin
        """
        provider_controller = cls.get_datasource_plugin_provider(
            provider_id,
            tenant_id,
            datasource_type,
        )
        return provider_controller.get_datasource(datasource_name)