diff --git a/.semversioner/next-release/patch-20241022221802983784.json b/.semversioner/next-release/patch-20241022221802983784.json
new file mode 100644
index 00000000..6b73219b
--- /dev/null
+++ b/.semversioner/next-release/patch-20241022221802983784.json
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Add runtime-only storage option."
+}
diff --git a/graphrag/index/context.py b/graphrag/index/context.py
index 48c1cd80..94934e49 100644
--- a/graphrag/index/context.py
+++ b/graphrag/index/context.py
@@ -34,7 +34,11 @@ class PipelineRunContext:
 
     stats: PipelineRunStats
     storage: PipelineStorage
+    "Long-term storage for pipeline verbs to use. Items written here will be written to the storage provider."
     cache: PipelineCache
+    "Cache instance for reading previous LLM responses."
+    runtime_storage: PipelineStorage
+    "Runtime-only storage for pipeline verbs to use. Items written here will only live in memory during the current run."
 
 
 # TODO: For now, just has the same props available to it
diff --git a/graphrag/index/run/profiling.py b/graphrag/index/run/profiling.py
index d1a54a66..6121b9ea 100644
--- a/graphrag/index/run/profiling.py
+++ b/graphrag/index/run/profiling.py
@@ -69,7 +69,3 @@ async def _write_workflow_stats(
         await _save_profiler_stats(
             storage, workflow.name, workflow_result.memory_profile
         )
-
-    log.debug(
-        "first row of %s => %s", workflow.name, workflow.output().iloc[0].to_json()
-    )
diff --git a/graphrag/index/run/run.py b/graphrag/index/run/run.py
index 980e804c..6f3e6721 100644
--- a/graphrag/index/run/run.py
+++ b/graphrag/index/run/run.py
@@ -32,8 +32,8 @@ from graphrag.index.run.utils import (
     _apply_substitutions,
     _create_input,
     _create_reporter,
-    _create_run_context,
     _validate_dataset,
+    create_run_context,
 )
 from graphrag.index.run.workflow import (
     _create_callback_chain,
@@ -200,7 +200,7 @@ async def run_pipeline(
     """
     start_time = time.time()
 
-    context = _create_run_context(storage=storage, cache=cache, stats=None)
+    context = create_run_context(storage=storage, cache=cache, stats=None)
     progress_reporter = progress_reporter or NullProgressReporter()
     callbacks = callbacks or ConsoleWorkflowCallbacks()
diff --git a/graphrag/index/run/utils.py b/graphrag/index/run/utils.py
index 0b4e9a80..49650d62 100644
--- a/graphrag/index/run/utils.py
+++ b/graphrag/index/run/utils.py
@@ -103,7 +103,7 @@ def _apply_substitutions(config: PipelineConfig, run_id: str) -> PipelineConfig:
     return config
 
 
-def _create_run_context(
+def create_run_context(
     storage: PipelineStorage | None,
     cache: PipelineCache | None,
     stats: PipelineRunStats | None,
@@ -113,4 +113,5 @@ def _create_run_context(
         stats=stats or PipelineRunStats(),
         cache=cache or InMemoryCache(),
         storage=storage or MemoryPipelineStorage(),
+        runtime_storage=MemoryPipelineStorage(),
     )
diff --git a/graphrag/index/run/workflow.py b/graphrag/index/run/workflow.py
index aa288262..e8f001a2 100644
--- a/graphrag/index/run/workflow.py
+++ b/graphrag/index/run/workflow.py
@@ -39,7 +39,18 @@ async def _inject_workflow_data_dependencies(
     log.info("dependencies for %s: %s", workflow.name, deps)
     for id in deps:
         workflow_id = f"workflow:{id}"
-        table = await _load_table_from_storage(f"{id}.parquet", storage)
+        try:
+            table = await _load_table_from_storage(f"{id}.parquet", storage)
+        except ValueError:
+            # our workflows now allow transient tables, and we avoid putting those in primary storage;
+            # however, we need to keep the table in the dependency list for proper execution order.
+            # this lets us catch missing-table errors but emit only a warning for pipeline users who may genuinely have an error (which we expect to be very rare)
+            # TODO: this issue will resolve itself if we remove DataShaper completely
+            log.warning(
+                "Dependency table %s not found in storage: it may be a runtime-only in-memory table. If you see further errors, this may be an actual problem.",
+                id,
+            )
+            table = pd.DataFrame()
         workflow.add_table(workflow_id, table)
diff --git a/graphrag/index/storage/blob_pipeline_storage.py b/graphrag/index/storage/blob_pipeline_storage.py
index 5494c5a6..bdf25c99 100644
--- a/graphrag/index/storage/blob_pipeline_storage.py
+++ b/graphrag/index/storage/blob_pipeline_storage.py
@@ -276,6 +276,11 @@ class BlobPipelineStorage(PipelineStorage):
             self._storage_account_blob_url,
         )
 
+    def keys(self) -> list[str]:
+        """Return the keys in the storage."""
+        msg = "Blob storage does not yet support listing keys."
+        raise NotImplementedError(msg)
+
     def _keyname(self, key: str) -> str:
         """Get the key name."""
         return str(Path(self._path_prefix) / key)
diff --git a/graphrag/index/storage/file_pipeline_storage.py b/graphrag/index/storage/file_pipeline_storage.py
index ccb06d1f..a3a18cf4 100644
--- a/graphrag/index/storage/file_pipeline_storage.py
+++ b/graphrag/index/storage/file_pipeline_storage.py
@@ -143,6 +143,10 @@ class FilePipelineStorage(PipelineStorage):
             return self
         return FilePipelineStorage(str(Path(self._root_dir) / Path(name)))
 
+    def keys(self) -> list[str]:
+        """Return the keys in the storage."""
+        return os.listdir(self._root_dir)
+
 
 def join_path(file_path: str, file_name: str) -> Path:
     """Join a path and a file. Independent of the OS."""
diff --git a/graphrag/index/storage/memory_pipeline_storage.py b/graphrag/index/storage/memory_pipeline_storage.py
index ed490983..3f9f9c9b 100644
--- a/graphrag/index/storage/memory_pipeline_storage.py
+++ b/graphrag/index/storage/memory_pipeline_storage.py
@@ -32,7 +32,7 @@ class MemoryPipelineStorage(FilePipelineStorage):
         -------
             - output - The value for the given key.
         """
-        return self._storage.get(key) or await super().get(key, as_bytes, encoding)
+        return self._storage.get(key)
 
     async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
         """Set the value for the given key.
@@ -53,7 +53,7 @@ class MemoryPipelineStorage(FilePipelineStorage):
         -------
             - output - True if the key exists in the storage, False otherwise.
         """
-        return key in self._storage or await super().has(key)
+        return key in self._storage
 
     async def delete(self, key: str) -> None:
         """Delete the given key from the storage.
@@ -69,7 +69,7 @@ class MemoryPipelineStorage(FilePipelineStorage):
 
     def child(self, name: str | None) -> "PipelineStorage":
         """Create a child storage instance."""
-        return self
+        return MemoryPipelineStorage()
 
     def keys(self) -> list[str]:
         """Return the keys in the storage."""
diff --git a/graphrag/index/storage/pipeline_storage.py b/graphrag/index/storage/pipeline_storage.py
index 08ccb409..554c2ffd 100644
--- a/graphrag/index/storage/pipeline_storage.py
+++ b/graphrag/index/storage/pipeline_storage.py
@@ -76,3 +76,7 @@ class PipelineStorage(metaclass=ABCMeta):
     @abstractmethod
     def child(self, name: str | None) -> "PipelineStorage":
         """Create a child storage instance."""
+
+    @abstractmethod
+    def keys(self) -> list[str]:
+        """List all keys in the storage."""
diff --git a/graphrag/index/workflows/v1/subflows/create_base_entity_graph.py b/graphrag/index/workflows/v1/subflows/create_base_entity_graph.py
index 2ab85b01..dea94792 100644
--- a/graphrag/index/workflows/v1/subflows/create_base_entity_graph.py
+++ b/graphrag/index/workflows/v1/subflows/create_base_entity_graph.py
@@ -5,12 +5,10 @@
 
 from typing import Any, cast
 
-import pandas as pd
 from datashaper import (
     AsyncType,
     Table,
     VerbCallbacks,
-    VerbInput,
     verb,
 )
 from datashaper.table_store.types import VerbResult, create_verb_result
@@ -27,10 +25,10 @@ from graphrag.index.storage import PipelineStorage
     treats_input_tables_as_immutable=True,
 )
 async def create_base_entity_graph(
-    input: VerbInput,
     callbacks: VerbCallbacks,
     cache: PipelineCache,
     storage: PipelineStorage,
+    runtime_storage: PipelineStorage,
     text_column: str,
     id_column: str,
     clustering_strategy: dict[str, Any],
@@ -48,10 +46,10 @@
     **_kwargs: dict,
 ) -> VerbResult:
     """All the steps to create the base entity graph."""
-    source = cast(pd.DataFrame, input.get_input())
+    text_units = await runtime_storage.get("base_text_units")
 
     output = await create_base_entity_graph_flow(
-        source,
+        text_units,
         callbacks,
         cache,
         storage,
diff --git a/graphrag/index/workflows/v1/subflows/create_base_text_units.py b/graphrag/index/workflows/v1/subflows/create_base_text_units.py
index 18c65008..2518111d 100644
--- a/graphrag/index/workflows/v1/subflows/create_base_text_units.py
+++ b/graphrag/index/workflows/v1/subflows/create_base_text_units.py
@@ -17,12 +17,14 @@ from datashaper.table_store.types import VerbResult, create_verb_result
 from graphrag.index.flows.create_base_text_units import (
     create_base_text_units as create_base_text_units_flow,
 )
+from graphrag.index.storage import PipelineStorage
 
 
 @verb(name="create_base_text_units", treats_input_tables_as_immutable=True)
-def create_base_text_units(
+async def create_base_text_units(
     input: VerbInput,
     callbacks: VerbCallbacks,
+    runtime_storage: PipelineStorage,
     chunk_column_name: str,
     n_tokens_column_name: str,
     chunk_by_columns: list[str],
@@ -41,9 +43,11 @@
         chunk_strategy=chunk_strategy,
     )
 
+    await runtime_storage.set("base_text_units", output)
+
     return create_verb_result(
         cast(
             Table,
-            output,
+            pd.DataFrame(),
         )
     )
diff --git a/graphrag/index/workflows/v1/subflows/create_final_covariates.py b/graphrag/index/workflows/v1/subflows/create_final_covariates.py
index d0812bc5..f6998770 100644
--- a/graphrag/index/workflows/v1/subflows/create_final_covariates.py
+++ b/graphrag/index/workflows/v1/subflows/create_final_covariates.py
@@ -5,12 +5,10 @@
 
 from typing import Any, cast
 
-import pandas as pd
 from datashaper import (
     AsyncType,
     Table,
     VerbCallbacks,
-    VerbInput,
     verb,
 )
 from datashaper.table_store.types import VerbResult, create_verb_result
@@ -19,13 +17,14 @@ from graphrag.index.cache import PipelineCache
 from graphrag.index.flows.create_final_covariates import (
     create_final_covariates as create_final_covariates_flow,
 )
+from graphrag.index.storage import PipelineStorage
 
 
 @verb(name="create_final_covariates", treats_input_tables_as_immutable=True)
 async def create_final_covariates(
-    input: VerbInput,
     callbacks: VerbCallbacks,
     cache: PipelineCache,
+    runtime_storage: PipelineStorage,
     column: str,
     covariate_type: str,
     extraction_strategy: dict[str, Any] | None,
@@ -35,10 +34,10 @@
     **_kwargs: dict,
 ) -> VerbResult:
     """All the steps to extract and format covariates."""
-    source = cast(pd.DataFrame, input.get_input())
+    text_units = await runtime_storage.get("base_text_units")
 
     output = await create_final_covariates_flow(
-        source,
+        text_units,
         callbacks,
         cache,
         column,
diff --git a/graphrag/index/workflows/v1/subflows/create_final_text_units.py b/graphrag/index/workflows/v1/subflows/create_final_text_units.py
index 14bd8399..37ddb106 100644
--- a/graphrag/index/workflows/v1/subflows/create_final_text_units.py
+++ b/graphrag/index/workflows/v1/subflows/create_final_text_units.py
@@ -19,6 +19,7 @@ from graphrag.index.cache import PipelineCache
 from graphrag.index.flows.create_final_text_units import (
     create_final_text_units as create_final_text_units_flow,
 )
+from graphrag.index.storage import PipelineStorage
 from graphrag.index.utils.ds_util import get_named_input_table, get_required_input_table
 
 
@@ -27,11 +28,12 @@ async def create_final_text_units(
     input: VerbInput,
     callbacks: VerbCallbacks,
     cache: PipelineCache,
+    runtime_storage: PipelineStorage,
     text_text_embed: dict | None = None,
     **_kwargs: dict,
 ) -> VerbResult:
     """All the steps to transform the text units."""
-    source = cast(pd.DataFrame, input.get_input())
+    text_units = await runtime_storage.get("base_text_units")
     final_entities = cast(
         pd.DataFrame, get_required_input_table(input, "entities").table
     )
@@ -44,7 +46,7 @@
         final_covariates = cast(pd.DataFrame, final_covariates.table)
 
     output = await create_final_text_units_flow(
-        source,
+        text_units,
         final_entities,
         final_relationships,
         final_covariates,
diff --git a/tests/fixtures/min-csv/config.json b/tests/fixtures/min-csv/config.json
index db07f773..78e601f1 100644
--- a/tests/fixtures/min-csv/config.json
+++ b/tests/fixtures/min-csv/config.json
@@ -16,7 +16,7 @@
             2500
         ],
         "subworkflows": 1,
-        "max_runtime": 100
+        "max_runtime": 300
     },
     "create_final_entities": {
         "row_range": [
diff --git a/tests/smoke/test_fixtures.py b/tests/smoke/test_fixtures.py
index c5aff3d9..33b0b6a0 100644
--- a/tests/smoke/test_fixtures.py
+++ b/tests/smoke/test_fixtures.py
@@ -178,30 +178,26 @@ class TestIndexer:
             workflows == expected_workflows
         ), f"Workflows missing from stats.json: {expected_workflows - workflows}. Unexpected workflows in stats.json: {workflows - expected_workflows}"
 
-        # [OPTIONAL] Check subworkflows
+        # [OPTIONAL] Check runtime
         for workflow in expected_workflows:
-            if "subworkflows" in workflow_config[workflow]:
-                # Check number of subworkflows
-                subworkflows = stats["workflows"][workflow]
-                expected_subworkflows = workflow_config[workflow].get(
-                    "subworkflows", None
-                )
-                if expected_subworkflows:
-                    assert (
-                        len(subworkflows) - 1 == expected_subworkflows
-                    ), f"Expected {expected_subworkflows} subworkflows, found: {len(subworkflows) - 1} for workflow: {workflow}: [{subworkflows}]"
-
-                # Check max runtime
-                max_runtime = workflow_config[workflow].get("max_runtime", None)
-                if max_runtime:
-                    assert (
-                        stats["workflows"][workflow]["overall"] <= max_runtime
-                    ), f"Expected max runtime of {max_runtime}, found: {stats['workflows'][workflow]['overall']} for workflow: {workflow}"
+            # Check max runtime
+            max_runtime = workflow_config[workflow].get("max_runtime", None)
+            if max_runtime:
+                assert (
+                    stats["workflows"][workflow]["overall"] <= max_runtime
+                ), f"Expected max runtime of {max_runtime}, found: {stats['workflows'][workflow]['overall']} for workflow: {workflow}"
 
         # Check artifacts
         artifact_files = os.listdir(artifacts)
 
+        # check that the number of workflows matches the number of artifacts, but:
+        # (1) do not count workflows with only transient output
+        # (2) account for the stats.json file
+        transient_workflows = [
+            "workflow:create_base_text_units",
+        ]
         assert (
-            len(artifact_files) == len(expected_workflows) + 1
-        ), f"Expected {len(expected_workflows) + 1} artifacts, found: {len(artifact_files)}"
+            len(artifact_files)
+            == (len(expected_workflows) - len(transient_workflows) + 1)
+        ), f"Expected {len(expected_workflows) - len(transient_workflows) + 1} artifacts, found: {len(artifact_files)}"
 
         for artifact in artifact_files:
diff --git a/tests/verbs/test_create_base_entity_graph.py b/tests/verbs/test_create_base_entity_graph.py
index 86868585..ab03d4ac 100644
--- a/tests/verbs/test_create_base_entity_graph.py
+++ b/tests/verbs/test_create_base_entity_graph.py
@@ -5,7 +5,7 @@ import networkx as nx
 import pytest
 
 from graphrag.config.enums import LLMType
-from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage
+from graphrag.index.run.utils import create_run_context
 from graphrag.index.workflows.v1.create_base_entity_graph import (
     build_steps,
     workflow_name,
@@ -55,7 +55,10 @@
     ])
     expected = load_expected(workflow_name)
 
-    storage = MemoryPipelineStorage()
+    context = create_run_context(None, None, None)
+    await context.runtime_storage.set(
+        "base_text_units", input_tables["workflow:create_base_text_units"]
+    )
 
     config = get_config_for_workflow(workflow_name)
     config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_ENTITY_CONFIG
@@ -68,7 +71,7 @@
         {
             "steps": steps,
         },
-        storage=storage,
+        context=context,
     )
 
     assert len(actual.columns) == len(
@@ -88,7 +91,7 @@
     nodes = list(actual_graph_0.nodes(data=True))
     assert nodes[0][1]["description"] == "Company_A is a test company"
 
-    assert len(storage.keys()) == 0, "Storage should be empty"
+    assert len(context.storage.keys()) == 0, "Storage should be empty"
 
 
 async def test_create_base_entity_graph_with_embeddings():
@@ -97,6 +100,11 @@
     ])
     expected = load_expected(workflow_name)
 
+    context = create_run_context(None, None, None)
+    await context.runtime_storage.set(
+        "base_text_units", input_tables["workflow:create_base_text_units"]
+    )
+
     config = get_config_for_workflow(workflow_name)
 
     config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_ENTITY_CONFIG
@@ -110,6 +118,7 @@
         {
             "steps": steps,
         },
+        context=context,
     )
 
     assert (
@@ -123,7 +132,10 @@
         "workflow:create_base_text_units",
     ])
 
-    storage = MemoryPipelineStorage()
+    context = create_run_context(None, None, None)
+    await context.runtime_storage.set(
+        "base_text_units", input_tables["workflow:create_base_text_units"]
+    )
 
     config = get_config_for_workflow(workflow_name)
 
@@ -140,10 +152,10 @@
         {
             "steps": steps,
         },
-        storage=storage,
+        context=context,
     )
 
-    assert storage.keys() == [
+    assert context.storage.keys() == [
         "raw_extracted_entities.json",
         "merged_graph.graphml",
         "summarized_graph.graphml",
@@ -157,6 +169,11 @@ async def test_create_base_entity_graph_missing_llm_throws():
         "workflow:create_base_text_units",
     ])
 
+    context = create_run_context(None, None, None)
+    await context.runtime_storage.set(
+        "base_text_units", input_tables["workflow:create_base_text_units"]
+    )
+
     config = get_config_for_workflow(workflow_name)
 
     config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_ENTITY_CONFIG
@@ -170,4 +187,5 @@
         {
             "steps": steps,
         },
+        context=context,
     )
diff --git a/tests/verbs/test_create_base_text_units.py b/tests/verbs/test_create_base_text_units.py
index bb3bb0ee..8fd6fee2 100644
--- a/tests/verbs/test_create_base_text_units.py
+++ b/tests/verbs/test_create_base_text_units.py
@@ -1,6 +1,7 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License
 
+from graphrag.index.run.utils import create_run_context
 from graphrag.index.workflows.v1.create_base_text_units import (
     build_steps,
     workflow_name,
@@ -19,17 +20,21 @@ async def test_create_base_text_units():
     input_tables = load_input_tables(inputs=[])
     expected = load_expected(workflow_name)
 
+    context = create_run_context(None, None, None)
+
     config = get_config_for_workflow(workflow_name)
 
     # test data was created with 4o, so we need to match the encoding for chunks to be identical
     config["text_chunk"]["strategy"]["encoding_name"] = "o200k_base"
 
     steps = build_steps(config)
 
-    actual = await get_workflow_output(
+    await get_workflow_output(
         input_tables,
         {
             "steps": steps,
         },
+        context,
     )
+    actual = await context.runtime_storage.get("base_text_units")
 
     compare_outputs(actual, expected)
diff --git a/tests/verbs/test_create_final_covariates.py b/tests/verbs/test_create_final_covariates.py
index 4251583e..4a2aa8e6 100644
--- a/tests/verbs/test_create_final_covariates.py
+++ b/tests/verbs/test_create_final_covariates.py
@@ -6,6 +6,7 @@ from datashaper.errors import VerbParallelizationError
 from pandas.testing import assert_series_equal
 
 from graphrag.config.enums import LLMType
+from graphrag.index.run.utils import create_run_context
 from graphrag.index.workflows.v1.create_final_covariates import (
     build_steps,
     workflow_name,
@@ -31,6 +32,11 @@ async def test_create_final_covariates():
     input_tables = load_input_tables(["workflow:create_base_text_units"])
     expected = load_expected(workflow_name)
 
+    context = create_run_context(None, None, None)
+    await context.runtime_storage.set(
+        "base_text_units", input_tables["workflow:create_base_text_units"]
+    )
+
     config = get_config_for_workflow(workflow_name)
 
     config["claim_extract"]["strategy"]["llm"] = MOCK_LLM_CONFIG
@@ -42,6 +48,7 @@ async def test_create_final_covariates():
         {
             "steps": steps,
         },
+        context,
     )
 
     input = input_tables["workflow:create_base_text_units"]
@@ -81,6 +88,11 @@ async def test_create_final_covariates():
 async def test_create_final_covariates_missing_llm_throws():
     input_tables = load_input_tables(["workflow:create_base_text_units"])
 
+    context = create_run_context(None, None, None)
+    await context.runtime_storage.set(
+        "base_text_units", input_tables["workflow:create_base_text_units"]
+    )
+
     config = get_config_for_workflow(workflow_name)
     del config["claim_extract"]["strategy"]["llm"]
@@ -93,4 +105,5 @@
         {
             "steps": steps,
         },
+        context,
     )
diff --git a/tests/verbs/test_create_final_nodes.py b/tests/verbs/test_create_final_nodes.py
index cd8d4dd9..6256e67e 100644
--- a/tests/verbs/test_create_final_nodes.py
+++ b/tests/verbs/test_create_final_nodes.py
@@ -1,7 +1,7 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License
 
-from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage
+from graphrag.index.run.utils import create_run_context
 from graphrag.index.workflows.v1.create_final_nodes import (
     build_steps,
     workflow_name,
@@ -22,7 +22,7 @@ async def test_create_final_nodes():
     ])
     expected = load_expected(workflow_name)
 
-    storage = MemoryPipelineStorage()
+    context = create_run_context(None, None, None)
 
     config = get_config_for_workflow(workflow_name)
 
@@ -37,12 +37,12 @@
         {
             "steps": steps,
         },
-        storage=storage,
+        context=context,
     )
 
     compare_outputs(actual, expected)
 
-    assert len(storage.keys()) == 0, "Storage should be empty"
+    assert len(context.storage.keys()) == 0, "Storage should be empty"
 
 
 async def test_create_final_nodes_with_snapshot():
@@ -51,7 +51,7 @@
     ])
     expected = load_expected(workflow_name)
 
-    storage = MemoryPipelineStorage()
+    context = create_run_context(None, None, None)
 
     config = get_config_for_workflow(workflow_name)
 
@@ -67,11 +67,11 @@
         {
             "steps": steps,
         },
-        storage=storage,
+        context=context,
     )
 
     assert actual.shape == expected.shape, "Graph dataframe shapes differ"
 
-    assert storage.keys() == [
+    assert context.storage.keys() == [
         "top_level_nodes.json",
     ], "Graph snapshot keys differ"
diff --git a/tests/verbs/test_create_final_text_units.py b/tests/verbs/test_create_final_text_units.py
index 4a3bee0a..9560d06a 100644
--- a/tests/verbs/test_create_final_text_units.py
+++ b/tests/verbs/test_create_final_text_units.py
@@ -1,6 +1,7 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License
 
+from graphrag.index.run.utils import create_run_context
 from graphrag.index.workflows.v1.create_final_text_units import (
     build_steps,
     workflow_name,
@@ -24,6 +25,11 @@ async def test_create_final_text_units():
     ])
     expected = load_expected(workflow_name)
 
+    context = create_run_context(None, None, None)
+    await context.runtime_storage.set(
+        "base_text_units", input_tables["workflow:create_base_text_units"]
+    )
+
     config = get_config_for_workflow(workflow_name)
 
     config["covariates_enabled"] = True
@@ -36,6 +42,7 @@
         {
             "steps": steps,
         },
+        context=context,
     )
 
     compare_outputs(actual, expected)
@@ -50,6 +57,11 @@ async def test_create_final_text_units_no_covariates():
     ])
     expected = load_expected(workflow_name)
 
+    context = create_run_context(None, None, None)
+    await context.runtime_storage.set(
+        "base_text_units", input_tables["workflow:create_base_text_units"]
+    )
+
     config = get_config_for_workflow(workflow_name)
 
     config["covariates_enabled"] = False
@@ -62,6 +74,7 @@
         {
             "steps": steps,
        },
+        context=context,
     )
 
     # we're short a covariate_ids column
@@ -81,6 +94,11 @@ async def test_create_final_text_units_with_embeddings():
     ])
     expected = load_expected(workflow_name)
 
+    context = create_run_context(None, None, None)
+    await context.runtime_storage.set(
+        "base_text_units", input_tables["workflow:create_base_text_units"]
+    )
+
     config = get_config_for_workflow(workflow_name)
 
     config["covariates_enabled"] = True
@@ -96,6 +114,7 @@
         {
             "steps": steps,
         },
+        context=context,
     )
 
     assert "text_embedding" in actual.columns
diff --git a/tests/verbs/util.py b/tests/verbs/util.py
index 9533bc61..389b9802 100644
--- a/tests/verbs/util.py
+++ b/tests/verbs/util.py
@@ -12,8 +12,8 @@ from graphrag.index import (
     PipelineWorkflowConfig,
     create_pipeline_config,
 )
-from graphrag.index.run.utils import _create_run_context
-from graphrag.index.storage.pipeline_storage import PipelineStorage
+from graphrag.index.context import PipelineRunContext
+from graphrag.index.run.utils import create_run_context
 
 pd.set_option("display.max_columns", None)
 
@@ -60,7 +60,7 @@ def get_config_for_workflow(name: str) -> PipelineWorkflowConfig:
 async def get_workflow_output(
     input_tables: dict[str, pd.DataFrame],
     schema: dict,
-    storage: PipelineStorage | None = None,
+    context: PipelineRunContext | None = None,
 ) -> pd.DataFrame:
     """Pass in the input tables, the schema, and the output name"""
 
@@ -70,9 +70,9 @@ async def get_workflow_output(
         input_tables=input_tables,
     )
 
-    context = _create_run_context(storage, None, None)
+    run_context = context or create_run_context(None, None, None)
 
-    await workflow.run(context=context)
+    await workflow.run(context=run_context)
 
     # if there's only one output, it is the default here, no name required
     return cast(pd.DataFrame, workflow.output())
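
Note for reviewers: the pattern this patch introduces is that a producing verb writes an intermediate table into context.runtime_storage under a well-known key ("base_text_units"), downstream verbs read it back from runtime_storage instead of receiving it through a DataShaper input table, and the transient table is never written to long-term context.storage. The following is a minimal sketch of that flow, not part of the patch; it assumes graphrag at this revision is importable, and the sample DataFrame contents are made up purely for illustration.

import asyncio

import pandas as pd

from graphrag.index.run.utils import create_run_context


async def main() -> None:
    # Passing None for storage/cache/stats falls back to in-memory defaults;
    # runtime_storage is always a fresh MemoryPipelineStorage (see utils.py above).
    context = create_run_context(None, None, None)

    # What create_base_text_units now does with its output (illustrative data).
    chunks = pd.DataFrame({"id": ["c1"], "chunk": ["some text"]})
    await context.runtime_storage.set("base_text_units", chunks)

    # What create_base_entity_graph / create_final_covariates / create_final_text_units
    # now do to obtain the text units.
    text_units = await context.runtime_storage.get("base_text_units")
    print(text_units)

    # The transient table never reaches long-term storage, so no artifact is written.
    print(context.storage.keys())  # -> []


if __name__ == "__main__":
    asyncio.run(main())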