graphrag/tests/verbs/test_pipeline_state.py
Nathan Evans bd06d8b4f0
Context property bag ("state") (#1774)
* Add pipeline state property bag to run context

* Move state creation out of context util

* Move callbacks into PipelineRunContext

* Semver

* Rename state.json to context.json to avoid confusion with stats.json

* Expand smoke test row count

* Add util to create storage and cache
2025-02-28 09:31:48 -08:00

55 lines
1.8 KiB
Python

# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Tests for pipeline state passthrough."""
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.run.utils import create_run_context
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.index.workflows.factory import PipelineFactory
from tests.verbs.util import DEFAULT_MODEL_CONFIG
async def run_workflow_1(  # noqa: RUF029
    _config: GraphRagConfig,
    context: PipelineRunContext,
):
    """Stub workflow that seeds the run context's state.

    Stores ``1`` under the ``"count"`` key of ``context.state`` and returns a
    ``WorkflowFunctionOutput`` whose ``result`` is ``None``. ``_config`` is
    accepted but never read.
    """
    # Declared async although nothing is awaited here (hence the noqa): the
    # tests in this module call the pipeline's functions with ``await``.
    state = context.state
    state["count"] = 1
    output = WorkflowFunctionOutput(result=None)
    return output
async def run_workflow_2(  # noqa: RUF029
    _config: GraphRagConfig,
    context: PipelineRunContext,
):
    """Stub workflow that bumps the counter held in the run context's state.

    Adds one, in place, to the value stored under ``"count"`` in
    ``context.state`` and returns a ``WorkflowFunctionOutput`` whose ``result``
    is ``None``. The entry must already be present (either written by an
    earlier workflow or supplied when the context was created). ``_config`` is
    accepted but never read.
    """
    # Declared async although nothing is awaited here (hence the noqa): the
    # tests in this module call the pipeline's functions with ``await``.
    state = context.state
    state["count"] += 1
    output = WorkflowFunctionOutput(result=None)
    return output
async def test_pipeline_state():
    """Check that workflows can update the arbitrary state block of the run context.

    Registers both stub workflows, builds a pipeline from a config that lists
    them, awaits every function the pipeline yields against one shared context,
    and asserts that the counter in ``context.state`` ends at 2.
    """
    stub_workflows = (
        ("workflow_1", run_workflow_1),
        ("workflow_2", run_workflow_2),
    )
    for workflow_name, workflow_fn in stub_workflows:
        PipelineFactory.register(workflow_name, workflow_fn)

    config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
    config.workflows = ["workflow_1", "workflow_2"]

    # No initial state is passed; workflow_1 is the one that writes "count" first.
    context = create_run_context()

    pipeline = PipelineFactory.create_pipeline(config)
    for _, step in pipeline.run():
        await step(config, context)

    final_count = context.state["count"]
    assert final_count == 2
async def test_pipeline_existing_state():
    """Check that state supplied at context creation is visible to workflows.

    Creates the run context with ``{"count": 4}``, runs a pipeline containing
    only the incrementing stub workflow, and asserts that the counter in
    ``context.state`` ends at 5.
    """
    PipelineFactory.register("workflow_2", run_workflow_2)

    config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
    config.workflows = ["workflow_2"]

    # Pre-populate the counter so the increment has a value to start from.
    initial_state = {"count": 4}
    context = create_run_context(state=initial_state)

    pipeline = PipelineFactory.create_pipeline(config)
    for _, step in pipeline.run():
        await step(config, context)

    final_count = context.state["count"]
    assert final_count == 5