feat(datasource): change datasource result type to event-stream

Dongyu Li 2025-06-19 11:10:24 +08:00
parent 02ae479636
commit 82d0a70cb4
4 changed files with 115 additions and 84 deletions
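
In short: PluginDatasourceManager now returns the plugin daemon's response stream directly instead of re-yielding it, and RagPipelineService.run_datasource_workflow_node wraps datasource execution in try/except blocks so failures surface as DatasourceErrorEvent items in the same event stream. A minimal consumption sketch; the keyword names below are read off variables in the diff, not from the method's full signature, which this diff does not show:

    for event in rag_pipeline_service.run_datasource_workflow_node(
        pipeline=pipeline,       # hypothetical caller-side objects
        node_id="datasource-1",
        user_inputs={},
        account=account,
        is_published=True,
    ):
        # each item is a plain dict, e.g.
        # {"event": "datasource_processing", "total": 0, "completed": 0}
        print(event["event"])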


@@ -213,10 +213,11 @@ class OnlineDocumentPage(BaseModel):
     """
     page_id: str = Field(..., description="The page id")
-    page_title: str = Field(..., description="The page title")
+    page_name: str = Field(..., description="The page title")
     page_icon: Optional[dict] = Field(None, description="The page icon")
     type: str = Field(..., description="The type of the page")
     last_edited_time: str = Field(..., description="The last edited time")
+    parent_id: Optional[str] = Field(None, description="The parent page id")


 class OnlineDocumentInfo(BaseModel):
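
The page_title -> page_name rename and the new optional parent_id affect every consumer of this model. A trimmed, runnable sketch of the updated shape (only the fields touched here):

    from typing import Optional

    from pydantic import BaseModel, Field

    class OnlineDocumentPage(BaseModel):
        page_id: str = Field(..., description="The page id")
        page_name: str = Field(..., description="The page title")
        parent_id: Optional[str] = Field(None, description="The parent page id")

    page = OnlineDocumentPage(page_id="p2", page_name="Child page", parent_id="p1")
    print(page.page_name)  # was page.page_title before this commit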


@@ -135,7 +135,7 @@ class PluginDatasourceManager(BasePluginClient):
         datasource_provider_id = GenericProviderID(datasource_provider)

-        response = self._request_with_plugin_daemon_response_stream(
+        return self._request_with_plugin_daemon_response_stream(
             "POST",
             f"plugin/{tenant_id}/dispatch/datasource/get_online_document_pages",
             OnlineDocumentPagesMessage,
@@ -153,7 +153,6 @@ class PluginDatasourceManager(BasePluginClient):
                 "Content-Type": "application/json",
             },
         )
-        yield from response

     def get_online_document_page_content(
         self,
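
Why return rather than yield from: with yield from, get_online_document_pages was itself a generator function, so the daemon request was only dispatched once a caller began iterating; returning the stream hands back the daemon's generator unwrapped, and any error raised while opening the stream surfaces at the call site. A toy illustration (not Dify code):

    def make_stream():
        print("request dispatched")  # stands in for eager setup, e.g. the HTTP call
        def gen():
            yield "chunk-1"
            yield "chunk-2"
        return gen()

    def wrapped():                # old style: a generator function
        yield from make_stream()  # setup deferred until first iteration

    def direct():                 # new style: a plain function
        return make_stream()      # setup happens here, at call time

    direct()   # prints "request dispatched" immediately
    wrapped()  # prints nothing until the result is iterated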


@@ -11,11 +11,16 @@ class DatasourceStreamEvent(Enum):
     """

     PROCESSING = "datasource_processing"
     COMPLETED = "datasource_completed"
+    ERROR = "datasource_error"


 class BaseDatasourceEvent(BaseModel):
     pass


+class DatasourceErrorEvent(BaseDatasourceEvent):
+    event: str = DatasourceStreamEvent.ERROR.value
+    error: str = Field(..., description="error message")
+
+
 class DatasourceCompletedEvent(BaseDatasourceEvent):
     event: str = DatasourceStreamEvent.COMPLETED.value
     data: Mapping[str, Any] | list = Field(..., description="result")
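
All of these models serialize to plain dicts via model_dump(), which is what the pipeline service below yields. A hedged sketch; the time_consuming field is inferred from its usage in the service hunks further down, and the SSE-style framing is illustrative only:

    import json

    from core.rag.entities.event import DatasourceCompletedEvent, DatasourceErrorEvent

    done = DatasourceCompletedEvent(data={"pages": []}, time_consuming=0.42)
    failed = DatasourceErrorEvent(error="crawler timed out")

    for event in (done, failed):
        payload = event.model_dump()  # e.g. {"event": "datasource_error", "error": "..."}
        print(f"data: {json.dumps(payload)}\n")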


@@ -1,4 +1,5 @@
 import json
+import logging
 import re
 import threading
 import time
@@ -21,7 +22,12 @@ from core.datasource.entities.datasource_entities import (
 )
 from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
 from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
-from core.rag.entities.event import BaseDatasourceEvent, DatasourceCompletedEvent, DatasourceProcessingEvent
+from core.rag.entities.event import (
+    BaseDatasourceEvent,
+    DatasourceCompletedEvent,
+    DatasourceErrorEvent,
+    DatasourceProcessingEvent,
+)
 from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
 from core.variables.variables import Variable
 from core.workflow.entities.node_entities import NodeRunResult
@@ -61,6 +67,7 @@ from services.entities.knowledge_entities.rag_pipeline_entities import (
 from services.errors.app import WorkflowHashNotEqualError
 from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory

+logger = logging.getLogger(__name__)


 class RagPipelineService:
     @classmethod
@@ -430,6 +437,7 @@
         """
         Run published workflow datasource
         """
+        try:
             if is_published:
                 # fetch published workflow by app_model
                 workflow = self.get_published_workflow(pipeline=pipeline)
@@ -440,7 +448,6 @@
             # run draft workflow node
             datasource_node_data = None
             start_at = time.perf_counter()
-
             datasource_nodes = workflow.graph_dict.get("nodes", [])
             for datasource_node in datasource_nodes:
                 if datasource_node.get("id") == node_id:
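
The loop above locates the datasource node's configuration by scanning the workflow graph. A sketch of the assumed graph_dict shape; field names beyond "nodes" and "id" are illustrative:

    graph_dict = {
        "nodes": [
            {"id": "datasource-1", "data": {"type": "datasource"}},
            {"id": "llm-1", "data": {"type": "llm"}},
        ]
    }

    node_id = "datasource-1"
    datasource_node = next(
        (node for node in graph_dict.get("nodes", []) if node.get("id") == node_id),
        None,
    )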
@@ -481,6 +488,12 @@
                         )
                     )
                     start_time = time.time()
+                    start_event = DatasourceProcessingEvent(
+                        total=0,
+                        completed=0,
+                    )
+                    yield start_event.model_dump()
+                    try:
                         for message in online_document_result:
                             end_time = time.time()
                             online_document_event = DatasourceCompletedEvent(
@@ -488,14 +501,19 @@
                                 time_consuming=round(end_time - start_time, 2)
                             )
                             yield online_document_event.model_dump()
+                    except Exception as e:
+                        logger.exception("Error during online document.")
+                        yield DatasourceErrorEvent(
+                            error=str(e)
+                        ).model_dump()
                 case DatasourceProviderType.WEBSITE_CRAWL:
                     datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
-                    website_crawl_result: Generator[WebsiteCrawlMessage, None, None] = datasource_runtime.get_website_crawl(
+                    website_crawl_result: Generator[WebsiteCrawlMessage, None, None] = (
+                        datasource_runtime.get_website_crawl(
                             user_id=account.id,
                             datasource_parameters=user_inputs,
                             provider_type=datasource_runtime.datasource_provider_type(),
                         )
+                    )
                     start_time = time.time()
                     try:
                         for message in website_crawl_result:
@@ -514,9 +532,17 @@
                             )
                             yield crawl_event.model_dump()
                     except Exception as e:
-                        print(str(e))
+                        logger.exception("Error during website crawl.")
+                        yield DatasourceErrorEvent(
+                            error=str(e)
+                        ).model_dump()
                 case _:
                     raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")
+        except Exception as e:
+            logger.exception("Error in run_datasource_workflow_node.")
+            yield DatasourceErrorEvent(
+                error=str(e)
+            ).model_dump()

     def run_free_workflow_node(
         self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
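
With the outer try/except in place, every code path in run_datasource_workflow_node ends in an event dict, including failures (the stray print(str(e)) is gone in favor of logger.exception plus a DatasourceErrorEvent). A hedged sketch of a downstream consumer; the framing is illustrative, not part of this commit:

    import json

    def stream_response(events):
        for event in events:
            yield f"data: {json.dumps(event)}\n\n"
            if event.get("event") == "datasource_error":
                return  # an error event terminates the stream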