Merge branch 'feat/r2' into deploy/rag-dev

Dongyu Li 2025-06-17 18:24:52 +08:00
commit 2d01b1a808
4 changed files with 50 additions and 22 deletions

View File

@@ -406,10 +406,10 @@ class RagPipelinePublishedDatasourceNodeRunApi(Resource):
         args = parser.parse_args()
         inputs = args.get("inputs")
-        if inputs == None:
+        if inputs is None:
             raise ValueError("missing inputs")
         datasource_type = args.get("datasource_type")
-        if datasource_type == None:
+        if datasource_type is None:
             raise ValueError("missing datasource_type")
         rag_pipeline_service = RagPipelineService()
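
The change from "== None" to "is None" follows PEP 8: None is a singleton, and an identity check cannot be fooled by a custom __eq__. A minimal sketch (illustrative class, not from this commit) of where the two comparisons diverge:

class AlwaysEqual:
    def __eq__(self, other):
        return True  # pathological __eq__: claims equality with everything

obj = AlwaysEqual()
print(obj == None)  # True, even though obj is a real object
print(obj is None)  # False; identity comparison is reliable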

View File: core/rag/entities/event.py

@@ -0,0 +1,30 @@
+from collections.abc import Mapping
+from enum import Enum
+from typing import Any, Optional
+from pydantic import BaseModel, Field
+
+
+class DatasourceStreamEvent(Enum):
+    """
+    Datasource Stream event
+    """
+    PROCESSING = "processing"
+    COMPLETED = "completed"
+
+
+class BaseDatasourceEvent(BaseModel):
+    pass
+
+
+class DatasourceCompletedEvent(BaseDatasourceEvent):
+    event: str = DatasourceStreamEvent.COMPLETED.value
+    data: Mapping[str, Any] | list = Field(..., description="result")
+    total: Optional[int] = Field(..., description="total")
+    completed: Optional[int] = Field(..., description="completed")
+    time_consuming: Optional[float] = Field(..., description="time consuming")
+
+
+class DatasourceProcessingEvent(BaseDatasourceEvent):
+    event: str = DatasourceStreamEvent.PROCESSING.value
+    total: Optional[int] = Field(..., description="total")
+    completed: Optional[int] = Field(..., description="completed")
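
A quick usage sketch for the new models (illustrative, not part of the commit; assumes pydantic v2 and that the file above is importable as core.rag.entities.event). Note that the Field(...) defaults make total, completed, and time_consuming required even though they are Optional, so callers must pass them explicitly (None is allowed):

from core.rag.entities.event import DatasourceCompletedEvent, DatasourceProcessingEvent

progress = DatasourceProcessingEvent(total=10, completed=3)
print(progress.model_dump())
# {'event': 'processing', 'total': 10, 'completed': 3}

done = DatasourceCompletedEvent(
    data=[{"url": "https://example.com"}],  # hypothetical crawl payload
    total=10,
    completed=10,
    time_consuming=1.42,
)
print(done.model_dump())
# {'event': 'completed', 'data': [...], 'total': 10, 'completed': 10, 'time_consuming': 1.42}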

View File: core/workflow/graph_engine/entities/event.py

@@ -275,10 +275,3 @@ class AgentLogEvent(BaseAgentEvent):
 InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | BaseAgentEvent | BaseLoopEvent
-
-class DatasourceRunEvent(BaseModel):
-    status: str = Field(..., description="status")
-    data: Mapping[str, Any] | list = Field(..., description="result")
-    total: Optional[int] = Field(..., description="total")
-    completed: Optional[int] = Field(..., description="completed")
-    time_consuming: Optional[float] = Field(..., description="time consuming")
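
The status-discriminated DatasourceRunEvent is replaced by the two typed events from core.rag.entities.event. A migration sketch (hypothetical helper, not in the commit) showing how the old status string maps onto the new event types:

from core.rag.entities.event import (
    BaseDatasourceEvent,
    DatasourceCompletedEvent,
    DatasourceProcessingEvent,
)

def to_event(status, data, total, completed, elapsed) -> BaseDatasourceEvent:
    # mirrors the branch RagPipelineService now performs inline
    if status == "completed":
        return DatasourceCompletedEvent(
            data=data, total=total, completed=completed, time_consuming=elapsed
        )
    return DatasourceProcessingEvent(total=total, completed=completed)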

View File

@@ -21,6 +21,7 @@ from core.datasource.entities.datasource_entities import (
 )
 from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
 from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
+from core.rag.entities.event import BaseDatasourceEvent, DatasourceCompletedEvent, DatasourceProcessingEvent
 from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
 from core.variables.variables import Variable
 from core.workflow.entities.node_entities import NodeRunResult
@@ -30,7 +31,7 @@ from core.workflow.entities.workflow_node_execution import (
 )
 from core.workflow.enums import SystemVariableKey
 from core.workflow.errors import WorkflowNodeRunFailedError
-from core.workflow.graph_engine.entities.event import DatasourceRunEvent, InNodeEvent
+from core.workflow.graph_engine.entities.event import InNodeEvent
 from core.workflow.nodes.base.node import BaseNode
 from core.workflow.nodes.enums import ErrorStrategy, NodeType
 from core.workflow.nodes.event.event import RunCompletedEvent
@@ -486,7 +487,7 @@ class RagPipelineService:
     def run_datasource_workflow_node(
         self, pipeline: Pipeline, node_id: str, user_inputs: dict, account: Account, datasource_type: str,
         is_published: bool
-    ) -> Generator[str, None, None]:
+    ) -> Generator[BaseDatasourceEvent, None, None]:
         """
         Run published workflow datasource
        """
@@ -542,12 +543,11 @@ class RagPipelineService:
                 start_time = time.time()
                 for message in online_document_result:
                     end_time = time.time()
-                    online_document_event = DatasourceRunEvent(
-                        status="completed",
+                    online_document_event = DatasourceCompletedEvent(
                         data=message.result,
                         time_consuming=round(end_time - start_time, 2)
                     )
-                    yield json.dumps(online_document_event.model_dump())
+                    yield online_document_event.model_dump()
             case DatasourceProviderType.WEBSITE_CRAWL:
                 datasource_runtime = cast(WebsiteCrawlDatasourcePlugin, datasource_runtime)
@@ -559,14 +559,19 @@ class RagPipelineService:
                 start_time = time.time()
                 for message in website_crawl_result:
                     end_time = time.time()
-                    crawl_event = DatasourceRunEvent(
-                        status=message.result.status,
-                        data=message.result.web_info_list,
-                        total=message.result.total,
-                        completed=message.result.completed,
-                        time_consuming = round(end_time - start_time, 2)
-                    )
-                    yield json.dumps(crawl_event.model_dump())
+                    if message.result.status == "completed":
+                        crawl_event = DatasourceCompletedEvent(
+                            data=message.result.web_info_list,
+                            total=message.result.total,
+                            completed=message.result.completed,
+                            time_consuming=round(end_time - start_time, 2)
+                        )
+                    else:
+                        crawl_event = DatasourceProcessingEvent(
+                            total=message.result.total,
+                            completed=message.result.completed,
+                        )
+                    yield crawl_event.model_dump()
             case _:
                 raise ValueError(f"Unsupported datasource provider: {datasource_runtime.datasource_provider_type}")