mirror of
https://github.com/microsoft/graphrag.git
synced 2025-07-03 07:04:19 +00:00
55 lines
1.8 KiB
Python
55 lines
1.8 KiB
Python
![]() |
# Copyright (c) 2024 Microsoft Corporation.
|
||
|
# Licensed under the MIT License
|
||
|
|
||
|
"""Tests for pipeline state passthrough."""
|
||
|
|
||
|
from graphrag.config.create_graphrag_config import create_graphrag_config
|
||
|
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||
|
from graphrag.index.run.utils import create_run_context
|
||
|
from graphrag.index.typing.context import PipelineRunContext
|
||
|
from graphrag.index.typing.workflow import WorkflowFunctionOutput
|
||
|
from graphrag.index.workflows.factory import PipelineFactory
|
||
|
from tests.verbs.util import DEFAULT_MODEL_CONFIG
|
||
|
|
||
|
|
||
|
async def run_workflow_1( # noqa: RUF029
|
||
|
_config: GraphRagConfig, context: PipelineRunContext
|
||
|
):
|
||
|
context.state["count"] = 1
|
||
|
return WorkflowFunctionOutput(result=None)
|
||
|
|
||
|
|
||
|
async def run_workflow_2( # noqa: RUF029
|
||
|
_config: GraphRagConfig, context: PipelineRunContext
|
||
|
):
|
||
|
context.state["count"] += 1
|
||
|
return WorkflowFunctionOutput(result=None)
|
||
|
|
||
|
|
||
|
async def test_pipeline_state():
|
||
|
# checks that we can update the arbitrary state block within the pipeline run context
|
||
|
PipelineFactory.register("workflow_1", run_workflow_1)
|
||
|
PipelineFactory.register("workflow_2", run_workflow_2)
|
||
|
|
||
|
config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
|
||
|
config.workflows = ["workflow_1", "workflow_2"]
|
||
|
context = create_run_context()
|
||
|
|
||
|
for _, fn in PipelineFactory.create_pipeline(config).run():
|
||
|
await fn(config, context)
|
||
|
|
||
|
assert context.state["count"] == 2
|
||
|
|
||
|
|
||
|
async def test_pipeline_existing_state():
|
||
|
PipelineFactory.register("workflow_2", run_workflow_2)
|
||
|
|
||
|
config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
|
||
|
config.workflows = ["workflow_2"]
|
||
|
context = create_run_context(state={"count": 4})
|
||
|
|
||
|
for _, fn in PipelineFactory.create_pipeline(config).run():
|
||
|
await fn(config, context)
|
||
|
|
||
|
assert context.state["count"] == 5
|