feat(datasource): change datasource result type to event-stream

Dongyu Li 2025-06-19 11:10:24 +08:00
parent 02ae479636
commit 82d0a70cb4
4 changed files with 115 additions and 84 deletions

View File

@@ -213,10 +213,11 @@ class OnlineDocumentPage(BaseModel):
     """
     page_id: str = Field(..., description="The page id")
     page_title: str = Field(..., description="The page title")
+    page_name: str = Field(..., description="The page name")
     page_icon: Optional[dict] = Field(None, description="The page icon")
     type: str = Field(..., description="The type of the page")
     last_edited_time: str = Field(..., description="The last edited time")
     parent_id: Optional[str] = Field(None, description="The parent page id")


 class OnlineDocumentInfo(BaseModel):

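The hunk above adds page_name alongside page_title. A minimal, self-contained sketch of the updated model (re-declared from the fields shown above; the sample values are illustrative only, not from this commit):

    from typing import Optional
    from pydantic import BaseModel, Field

    class OnlineDocumentPage(BaseModel):
        page_id: str = Field(..., description="The page id")
        page_title: str = Field(..., description="The page title")
        page_name: str = Field(..., description="The page name")
        page_icon: Optional[dict] = Field(None, description="The page icon")
        type: str = Field(..., description="The type of the page")
        last_edited_time: str = Field(..., description="The last edited time")
        parent_id: Optional[str] = Field(None, description="The parent page id")

    page = OnlineDocumentPage(
        page_id="pg_123",
        page_title="Quarterly Report",
        page_name="Quarterly Report",
        type="page",
        last_edited_time="2025-06-19T11:10:24+08:00",
    )
    print(page.model_dump())  # plain dict, ready to embed in an event payload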
View File

@@ -135,7 +135,7 @@ class PluginDatasourceManager(BasePluginClient):
         datasource_provider_id = GenericProviderID(datasource_provider)

-        response = self._request_with_plugin_daemon_response_stream(
+        return self._request_with_plugin_daemon_response_stream(
             "POST",
             f"plugin/{tenant_id}/dispatch/datasource/get_online_document_pages",
             OnlineDocumentPagesMessage,
@@ -153,7 +153,6 @@ class PluginDatasourceManager(BasePluginClient):
                 "Content-Type": "application/json",
             },
         )
-        yield from response

     def get_online_document_page_content(
         self,

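The change above stops binding the daemon response to a local variable and yielding from it; the method now returns the response stream directly, so the enclosing function is no longer a generator itself. A small sketch of the difference (_fake_daemon_stream is a hypothetical stand-in for _request_with_plugin_daemon_response_stream):

    from collections.abc import Generator

    def _fake_daemon_stream() -> Generator[dict, None, None]:
        # stand-in for the plugin daemon response stream
        yield {"event": "datasource_processing"}
        yield {"event": "datasource_completed"}

    def old_style() -> Generator[dict, None, None]:
        response = _fake_daemon_stream()
        yield from response  # re-yields every item through an extra generator frame

    def new_style() -> Generator[dict, None, None]:
        return _fake_daemon_stream()  # hands the caller the same iterator directly

    assert list(old_style()) == list(new_style())

Both produce the same items; returning the stream simply drops the intermediate generator frame and, assuming the underlying call issues the request eagerly, starts the request at call time rather than on first iteration.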
View File

@@ -11,11 +11,16 @@ class DatasourceStreamEvent(Enum):
     """

     PROCESSING = "datasource_processing"
     COMPLETED = "datasource_completed"
+    ERROR = "datasource_error"


 class BaseDatasourceEvent(BaseModel):
     pass


+class DatasourceErrorEvent(BaseDatasourceEvent):
+    event: str = DatasourceStreamEvent.ERROR.value
+    error: str = Field(..., description="error message")


 class DatasourceCompletedEvent(BaseDatasourceEvent):
     event: str = DatasourceStreamEvent.COMPLETED.value
     data: Mapping[str, Any] | list = Field(..., description="result")
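Only the new error event is visible in this hunk; the processing event and the progress fields used by RagPipelineService below are not shown. A sketch of the full envelope, assuming DatasourceProcessingEvent and the total/completed/time_consuming fields match their call sites in RagPipelineService (those parts are inferred, not shown in this commit):

    from collections.abc import Mapping
    from enum import Enum
    from typing import Any, Optional

    from pydantic import BaseModel, Field

    class DatasourceStreamEvent(Enum):
        PROCESSING = "datasource_processing"
        COMPLETED = "datasource_completed"
        ERROR = "datasource_error"

    class BaseDatasourceEvent(BaseModel):
        pass

    class DatasourceErrorEvent(BaseDatasourceEvent):
        event: str = DatasourceStreamEvent.ERROR.value
        error: str = Field(..., description="error message")

    class DatasourceProcessingEvent(BaseDatasourceEvent):  # inferred from call sites
        event: str = DatasourceStreamEvent.PROCESSING.value
        total: int = 0
        completed: int = 0

    class DatasourceCompletedEvent(BaseDatasourceEvent):
        event: str = DatasourceStreamEvent.COMPLETED.value
        data: Mapping[str, Any] | list = Field(..., description="result")
        total: Optional[int] = None  # inferred from call sites
        completed: Optional[int] = None  # inferred from call sites
        time_consuming: Optional[float] = None  # inferred from call sites

    # consumers can route on the "event" discriminator:
    raw = DatasourceErrorEvent(error="boom").model_dump()
    assert raw["event"] == DatasourceStreamEvent.ERROR.value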

View File

@@ -1,4 +1,5 @@
 import json
+import logging
 import re
 import threading
 import time
@@ -21,7 +22,12 @@ from core.datasource.entities.datasource_entities import (
 )
 from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
 from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
-from core.rag.entities.event import BaseDatasourceEvent, DatasourceCompletedEvent, DatasourceProcessingEvent
+from core.rag.entities.event import (
+    BaseDatasourceEvent,
+    DatasourceCompletedEvent,
+    DatasourceErrorEvent,
+    DatasourceProcessingEvent,
+)
 from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
 from core.variables.variables import Variable
 from core.workflow.entities.node_entities import NodeRunResult
@@ -61,6 +67,7 @@ from services.entities.knowledge_entities.rag_pipeline_entities import (
 from services.errors.app import WorkflowHashNotEqualError
 from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory

+logger = logging.getLogger(__name__)


 class RagPipelineService:
     @classmethod
@@ -430,93 +437,112 @@ class RagPipelineService:
         """
         Run published workflow datasource
         """
-        if is_published:
-            # fetch published workflow by app_model
-            workflow = self.get_published_workflow(pipeline=pipeline)
-        else:
-            workflow = self.get_draft_workflow(pipeline=pipeline)
-        if not workflow:
-            raise ValueError("Workflow not initialized")
-
-        # run draft workflow node
-        datasource_node_data = None
-        start_at = time.perf_counter()
-        datasource_nodes = workflow.graph_dict.get("nodes", [])
-        for datasource_node in datasource_nodes:
-            if datasource_node.get("id") == node_id:
-                datasource_node_data = datasource_node.get("data", {})
-                break
-        if not datasource_node_data:
-            raise ValueError("Datasource node data not found")
-
-        datasource_parameters = datasource_node_data.get("datasource_parameters", {})
-        for key, value in datasource_parameters.items():
-            if not user_inputs.get(key):
-                user_inputs[key] = value["value"]
-
-        from core.datasource.datasource_manager import DatasourceManager
-
-        datasource_runtime = DatasourceManager.get_datasource_runtime(
-            provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
-            datasource_name=datasource_node_data.get("datasource_name"),
-            tenant_id=pipeline.tenant_id,
-            datasource_type=DatasourceProviderType(datasource_type),
-        )
-        datasource_provider_service = DatasourceProviderService()
-        credentials = datasource_provider_service.get_real_datasource_credentials(
-            tenant_id=pipeline.tenant_id,
-            provider=datasource_node_data.get("provider_name"),
-            plugin_id=datasource_node_data.get("plugin_id"),
-        )
-        if credentials:
-            datasource_runtime.runtime.credentials = credentials[0].get("credentials")
-        match datasource_type:
-            case DatasourceProviderType.ONLINE_DOCUMENT:
-                datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
-                online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
-                    datasource_runtime.get_online_document_pages(
-                        user_id=account.id,
-                        datasource_parameters=user_inputs,
-                        provider_type=datasource_runtime.datasource_provider_type(),
-                    )
-                )
-                start_time = time.time()
-                for message in online_document_result:
-                    end_time = time.time()
-                    online_document_event = DatasourceCompletedEvent(
-                        data=message.result,
-                        time_consuming=round(end_time - start_time, 2)
-                    )
-                    yield online_document_event.model_dump()
-            case DatasourceProviderType.WEBSITE_CRAWL:
-                datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
-                website_crawl_result: Generator[WebsiteCrawlMessage, None, None] = datasource_runtime.get_website_crawl(
-                    user_id=account.id,
-                    datasource_parameters=user_inputs,
-                    provider_type=datasource_runtime.datasource_provider_type(),
-                )
-                start_time = time.time()
-                try:
-                    for message in website_crawl_result:
-                        end_time = time.time()
-                        if message.result.status == "completed":
-                            crawl_event = DatasourceCompletedEvent(
-                                data=message.result.web_info_list,
-                                total=message.result.total,
-                                completed=message.result.completed,
-                                time_consuming=round(end_time - start_time, 2)
-                            )
-                        else:
-                            crawl_event = DatasourceProcessingEvent(
-                                total=message.result.total,
-                                completed=message.result.completed,
-                            )
-                        yield crawl_event.model_dump()
-                except Exception as e:
-                    print(str(e))
-            case _:
-                raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
+        try:
+            if is_published:
+                # fetch published workflow by app_model
+                workflow = self.get_published_workflow(pipeline=pipeline)
+            else:
+                workflow = self.get_draft_workflow(pipeline=pipeline)
+            if not workflow:
+                raise ValueError("Workflow not initialized")
+
+            # run draft workflow node
+            datasource_node_data = None
+            datasource_nodes = workflow.graph_dict.get("nodes", [])
+            for datasource_node in datasource_nodes:
+                if datasource_node.get("id") == node_id:
+                    datasource_node_data = datasource_node.get("data", {})
+                    break
+            if not datasource_node_data:
+                raise ValueError("Datasource node data not found")
+
+            datasource_parameters = datasource_node_data.get("datasource_parameters", {})
+            for key, value in datasource_parameters.items():
+                if not user_inputs.get(key):
+                    user_inputs[key] = value["value"]
+
+            from core.datasource.datasource_manager import DatasourceManager
+
+            datasource_runtime = DatasourceManager.get_datasource_runtime(
+                provider_id=f"{datasource_node_data.get('plugin_id')}/{datasource_node_data.get('provider_name')}",
+                datasource_name=datasource_node_data.get("datasource_name"),
+                tenant_id=pipeline.tenant_id,
+                datasource_type=DatasourceProviderType(datasource_type),
+            )
+            datasource_provider_service = DatasourceProviderService()
+            credentials = datasource_provider_service.get_real_datasource_credentials(
+                tenant_id=pipeline.tenant_id,
+                provider=datasource_node_data.get("provider_name"),
+                plugin_id=datasource_node_data.get("plugin_id"),
+            )
+            if credentials:
+                datasource_runtime.runtime.credentials = credentials[0].get("credentials")
+            match datasource_type:
+                case DatasourceProviderType.ONLINE_DOCUMENT:
+                    datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
+                    online_document_result: Generator[OnlineDocumentPagesMessage, None, None] = (
+                        datasource_runtime.get_online_document_pages(
+                            user_id=account.id,
+                            datasource_parameters=user_inputs,
+                            provider_type=datasource_runtime.datasource_provider_type(),
+                        )
+                    )
+                    start_time = time.time()
+                    start_event = DatasourceProcessingEvent(
+                        total=0,
+                        completed=0,
+                    )
+                    yield start_event.model_dump()
+                    try:
+                        for message in online_document_result:
+                            end_time = time.time()
+                            online_document_event = DatasourceCompletedEvent(
+                                data=message.result,
+                                time_consuming=round(end_time - start_time, 2)
+                            )
+                            yield online_document_event.model_dump()
+                    except Exception as e:
+                        logger.exception("Error during online document.")
+                        yield DatasourceErrorEvent(
+                            error=str(e)
+                        ).model_dump()
+                case DatasourceProviderType.WEBSITE_CRAWL:
+                    datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
+                    website_crawl_result: Generator[WebsiteCrawlMessage, None, None] = (
+                        datasource_runtime.get_website_crawl(
+                            user_id=account.id,
+                            datasource_parameters=user_inputs,
+                            provider_type=datasource_runtime.datasource_provider_type(),
+                        )
+                    )
+                    start_time = time.time()
+                    try:
+                        for message in website_crawl_result:
+                            end_time = time.time()
+                            if message.result.status == "completed":
+                                crawl_event = DatasourceCompletedEvent(
+                                    data=message.result.web_info_list,
+                                    total=message.result.total,
+                                    completed=message.result.completed,
+                                    time_consuming=round(end_time - start_time, 2)
+                                )
+                            else:
+                                crawl_event = DatasourceProcessingEvent(
+                                    total=message.result.total,
+                                    completed=message.result.completed,
+                                )
+                            yield crawl_event.model_dump()
+                    except Exception as e:
+                        logger.exception("Error during website crawl.")
+                        yield DatasourceErrorEvent(
+                            error=str(e)
+                        ).model_dump()
+                case _:
+                    raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
+        except Exception as e:
+            logger.exception("Error in run_datasource_workflow_node.")
+            yield DatasourceErrorEvent(
+                error=str(e)
+            ).model_dump()
def run_free_workflow_node(
self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
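With run_datasource_workflow_node now yielding an event stream, each item is a plain dict carrying its event discriminator, which maps naturally onto server-sent events. A hypothetical consumer (the to_sse helper and the Flask wiring are illustrative, not part of this commit):

    import json
    from collections.abc import Generator, Mapping
    from typing import Any

    def to_sse(events: Generator[Mapping[str, Any], None, None]) -> Generator[str, None, None]:
        # each dict carries "event": datasource_processing / datasource_completed /
        # datasource_error, so clients can route without extra framing
        for event in events:
            yield f"data: {json.dumps(dict(event))}\n\n"

    # e.g. with Flask:
    # return Response(
    #     to_sse(rag_pipeline_service.run_datasource_workflow_node(...)),
    #     mimetype="text/event-stream",
    # )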