mirror of
https://github.com/microsoft/graphrag.git
synced 2025-06-26 23:19:58 +00:00

* Add pipeline state property bag to run context * Move state creation out of context util * Move callbacks into PipelineRunContext * Semver * Rename state.json to context.json to avoid confusion with stats.json * Expand smoke test row count * Add util to create storage and cache
55 lines
1.8 KiB
Python
55 lines
1.8 KiB
Python
# Copyright (c) 2024 Microsoft Corporation.
|
|
# Licensed under the MIT License
|
|
|
|
"""Tests for pipeline state passthrough."""
|
|
|
|
from graphrag.config.create_graphrag_config import create_graphrag_config
|
|
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
|
from graphrag.index.run.utils import create_run_context
|
|
from graphrag.index.typing.context import PipelineRunContext
|
|
from graphrag.index.typing.workflow import WorkflowFunctionOutput
|
|
from graphrag.index.workflows.factory import PipelineFactory
|
|
from tests.verbs.util import DEFAULT_MODEL_CONFIG
|
|
|
|
|
|
async def run_workflow_1( # noqa: RUF029
|
|
_config: GraphRagConfig, context: PipelineRunContext
|
|
):
|
|
context.state["count"] = 1
|
|
return WorkflowFunctionOutput(result=None)
|
|
|
|
|
|
async def run_workflow_2( # noqa: RUF029
|
|
_config: GraphRagConfig, context: PipelineRunContext
|
|
):
|
|
context.state["count"] += 1
|
|
return WorkflowFunctionOutput(result=None)
|
|
|
|
|
|
async def test_pipeline_state():
|
|
# checks that we can update the arbitrary state block within the pipeline run context
|
|
PipelineFactory.register("workflow_1", run_workflow_1)
|
|
PipelineFactory.register("workflow_2", run_workflow_2)
|
|
|
|
config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
|
|
config.workflows = ["workflow_1", "workflow_2"]
|
|
context = create_run_context()
|
|
|
|
for _, fn in PipelineFactory.create_pipeline(config).run():
|
|
await fn(config, context)
|
|
|
|
assert context.state["count"] == 2
|
|
|
|
|
|
async def test_pipeline_existing_state():
|
|
PipelineFactory.register("workflow_2", run_workflow_2)
|
|
|
|
config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
|
|
config.workflows = ["workflow_2"]
|
|
context = create_run_context(state={"count": 4})
|
|
|
|
for _, fn in PipelineFactory.create_pipeline(config).run():
|
|
await fn(config, context)
|
|
|
|
assert context.state["count"] == 5
|