Rework workflow architecture (#1311)

* Rename pipeline_storage file

* Add runtime storage option to context

* Fix import

* Switch to memory storage for runtime

* Infra for workflow runtime storage

* Migrate base_text_units to runtime storage

* Fix comment

* Semver

* Remove whitespace

* Remove subflow smoke tests and ignore transient artifacts

* Remove entity graph from transient list (not yet implemented)

* Increase smoke runtime allotment for create_base_entity_graph

* Revert format fix

* Remove noqa
This commit is contained in:
Nathan Evans 2024-10-24 10:20:03 -07:00 committed by GitHub
parent ac09e0a740
commit 94f1e62e5c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 148 additions and 65 deletions

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Add runtime-only storage option."
}

View File

@ -34,7 +34,11 @@ class PipelineRunContext:
stats: PipelineRunStats
storage: PipelineStorage
"Long-term storage for pipeline verbs to use. Items written here will be written to the storage provider."
cache: PipelineCache
"Cache instance for reading previous LLM responses."
runtime_storage: PipelineStorage
"Runtime only storage for pipeline verbs to use. Items written here will only live in memory during the current run."
# TODO: For now, just has the same props available to it

View File

@ -69,7 +69,3 @@ async def _write_workflow_stats(
await _save_profiler_stats(
storage, workflow.name, workflow_result.memory_profile
)
log.debug(
"first row of %s => %s", workflow.name, workflow.output().iloc[0].to_json()
)

View File

@ -32,8 +32,8 @@ from graphrag.index.run.utils import (
_apply_substitutions,
_create_input,
_create_reporter,
_create_run_context,
_validate_dataset,
create_run_context,
)
from graphrag.index.run.workflow import (
_create_callback_chain,
@ -200,7 +200,7 @@ async def run_pipeline(
"""
start_time = time.time()
context = _create_run_context(storage=storage, cache=cache, stats=None)
context = create_run_context(storage=storage, cache=cache, stats=None)
progress_reporter = progress_reporter or NullProgressReporter()
callbacks = callbacks or ConsoleWorkflowCallbacks()

View File

@ -103,7 +103,7 @@ def _apply_substitutions(config: PipelineConfig, run_id: str) -> PipelineConfig:
return config
def _create_run_context(
def create_run_context(
storage: PipelineStorage | None,
cache: PipelineCache | None,
stats: PipelineRunStats | None,
@ -113,4 +113,5 @@ def _create_run_context(
stats=stats or PipelineRunStats(),
cache=cache or InMemoryCache(),
storage=storage or MemoryPipelineStorage(),
runtime_storage=MemoryPipelineStorage(),
)

View File

@ -39,7 +39,18 @@ async def _inject_workflow_data_dependencies(
log.info("dependencies for %s: %s", workflow.name, deps)
for id in deps:
workflow_id = f"workflow:{id}"
table = await _load_table_from_storage(f"{id}.parquet", storage)
try:
table = await _load_table_from_storage(f"{id}.parquet", storage)
except ValueError:
# our workflows now allow transient tables, and we avoid putting those in primary storage
# however, we need to keep the table in the dependency list for proper execution order
# this allows us to catch missing table errors but emit a warning for pipeline users who may genuinely have an error (which we expect to be very rare)
# todo: this issue will resolve itself if we remove DataShaper completely
log.warning(
"Dependency table %s not found in storage: it may be a runtime-only in-memory table. If you see further errors, this may be an actual problem.",
id,
)
table = pd.DataFrame()
workflow.add_table(workflow_id, table)

View File

@ -276,6 +276,11 @@ class BlobPipelineStorage(PipelineStorage):
self._storage_account_blob_url,
)
def keys(self) -> list[str]:
"""Return the keys in the storage."""
msg = "Blob storage does yet not support listing keys."
raise NotImplementedError(msg)
def _keyname(self, key: str) -> str:
"""Get the key name."""
return str(Path(self._path_prefix) / key)

View File

@ -143,6 +143,10 @@ class FilePipelineStorage(PipelineStorage):
return self
return FilePipelineStorage(str(Path(self._root_dir) / Path(name)))
def keys(self) -> list[str]:
"""Return the keys in the storage."""
return os.listdir(self._root_dir)
def join_path(file_path: str, file_name: str) -> Path:
"""Join a path and a file. Independent of the OS."""

View File

@ -32,7 +32,7 @@ class MemoryPipelineStorage(FilePipelineStorage):
-------
- output - The value for the given key.
"""
return self._storage.get(key) or await super().get(key, as_bytes, encoding)
return self._storage.get(key)
async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
"""Set the value for the given key.
@ -53,7 +53,7 @@ class MemoryPipelineStorage(FilePipelineStorage):
-------
- output - True if the key exists in the storage, False otherwise.
"""
return key in self._storage or await super().has(key)
return key in self._storage
async def delete(self, key: str) -> None:
"""Delete the given key from the storage.
@ -69,7 +69,7 @@ class MemoryPipelineStorage(FilePipelineStorage):
def child(self, name: str | None) -> "PipelineStorage":
"""Create a child storage instance."""
return self
return MemoryPipelineStorage()
def keys(self) -> list[str]:
"""Return the keys in the storage."""

View File

@ -76,3 +76,7 @@ class PipelineStorage(metaclass=ABCMeta):
@abstractmethod
def child(self, name: str | None) -> "PipelineStorage":
"""Create a child storage instance."""
@abstractmethod
def keys(self) -> list[str]:
"""List all keys in the storage."""

View File

@ -5,12 +5,10 @@
from typing import Any, cast
import pandas as pd
from datashaper import (
AsyncType,
Table,
VerbCallbacks,
VerbInput,
verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result
@ -27,10 +25,10 @@ from graphrag.index.storage import PipelineStorage
treats_input_tables_as_immutable=True,
)
async def create_base_entity_graph(
input: VerbInput,
callbacks: VerbCallbacks,
cache: PipelineCache,
storage: PipelineStorage,
runtime_storage: PipelineStorage,
text_column: str,
id_column: str,
clustering_strategy: dict[str, Any],
@ -48,10 +46,10 @@ async def create_base_entity_graph(
**_kwargs: dict,
) -> VerbResult:
"""All the steps to create the base entity graph."""
source = cast(pd.DataFrame, input.get_input())
text_units = await runtime_storage.get("base_text_units")
output = await create_base_entity_graph_flow(
source,
text_units,
callbacks,
cache,
storage,

View File

@ -17,12 +17,14 @@ from datashaper.table_store.types import VerbResult, create_verb_result
from graphrag.index.flows.create_base_text_units import (
create_base_text_units as create_base_text_units_flow,
)
from graphrag.index.storage import PipelineStorage
@verb(name="create_base_text_units", treats_input_tables_as_immutable=True)
def create_base_text_units(
async def create_base_text_units(
input: VerbInput,
callbacks: VerbCallbacks,
runtime_storage: PipelineStorage,
chunk_column_name: str,
n_tokens_column_name: str,
chunk_by_columns: list[str],
@ -41,9 +43,11 @@ def create_base_text_units(
chunk_strategy=chunk_strategy,
)
await runtime_storage.set("base_text_units", output)
return create_verb_result(
cast(
Table,
output,
pd.DataFrame(),
)
)

View File

@ -5,12 +5,10 @@
from typing import Any, cast
import pandas as pd
from datashaper import (
AsyncType,
Table,
VerbCallbacks,
VerbInput,
verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result
@ -19,13 +17,14 @@ from graphrag.index.cache import PipelineCache
from graphrag.index.flows.create_final_covariates import (
create_final_covariates as create_final_covariates_flow,
)
from graphrag.index.storage import PipelineStorage
@verb(name="create_final_covariates", treats_input_tables_as_immutable=True)
async def create_final_covariates(
input: VerbInput,
callbacks: VerbCallbacks,
cache: PipelineCache,
runtime_storage: PipelineStorage,
column: str,
covariate_type: str,
extraction_strategy: dict[str, Any] | None,
@ -35,10 +34,10 @@ async def create_final_covariates(
**_kwargs: dict,
) -> VerbResult:
"""All the steps to extract and format covariates."""
source = cast(pd.DataFrame, input.get_input())
text_units = await runtime_storage.get("base_text_units")
output = await create_final_covariates_flow(
source,
text_units,
callbacks,
cache,
column,

View File

@ -19,6 +19,7 @@ from graphrag.index.cache import PipelineCache
from graphrag.index.flows.create_final_text_units import (
create_final_text_units as create_final_text_units_flow,
)
from graphrag.index.storage import PipelineStorage
from graphrag.index.utils.ds_util import get_named_input_table, get_required_input_table
@ -27,11 +28,12 @@ async def create_final_text_units(
input: VerbInput,
callbacks: VerbCallbacks,
cache: PipelineCache,
runtime_storage: PipelineStorage,
text_text_embed: dict | None = None,
**_kwargs: dict,
) -> VerbResult:
"""All the steps to transform the text units."""
source = cast(pd.DataFrame, input.get_input())
text_units = await runtime_storage.get("base_text_units")
final_entities = cast(
pd.DataFrame, get_required_input_table(input, "entities").table
)
@ -44,7 +46,7 @@ async def create_final_text_units(
final_covariates = cast(pd.DataFrame, final_covariates.table)
output = await create_final_text_units_flow(
source,
text_units,
final_entities,
final_relationships,
final_covariates,

View File

@ -16,7 +16,7 @@
2500
],
"subworkflows": 1,
"max_runtime": 100
"max_runtime": 300
},
"create_final_entities": {
"row_range": [

View File

@ -178,30 +178,26 @@ class TestIndexer:
workflows == expected_workflows
), f"Workflows missing from stats.json: {expected_workflows - workflows}. Unexpected workflows in stats.json: {workflows - expected_workflows}"
# [OPTIONAL] Check subworkflows
# [OPTIONAL] Check runtime
for workflow in expected_workflows:
if "subworkflows" in workflow_config[workflow]:
# Check number of subworkflows
subworkflows = stats["workflows"][workflow]
expected_subworkflows = workflow_config[workflow].get(
"subworkflows", None
)
if expected_subworkflows:
assert (
len(subworkflows) - 1 == expected_subworkflows
), f"Expected {expected_subworkflows} subworkflows, found: {len(subworkflows) - 1} for workflow: {workflow}: [{subworkflows}]"
# Check max runtime
max_runtime = workflow_config[workflow].get("max_runtime", None)
if max_runtime:
assert (
stats["workflows"][workflow]["overall"] <= max_runtime
), f"Expected max runtime of {max_runtime}, found: {stats['workflows'][workflow]['overall']} for workflow: {workflow}"
# Check max runtime
max_runtime = workflow_config[workflow].get("max_runtime", None)
if max_runtime:
assert (
stats["workflows"][workflow]["overall"] <= max_runtime
), f"Expected max runtime of {max_runtime}, found: {stats['workflows'][workflow]['overall']} for workflow: {workflow}"
# Check artifacts
artifact_files = os.listdir(artifacts)
# check that the number of workflows matches the number of artifacts, but:
# (1) do not count workflows with only transient output
# (2) account for the stats.json file
transient_workflows = [
"workflow:create_base_text_units",
]
assert (
len(artifact_files) == len(expected_workflows) + 1
len(artifact_files)
== (len(expected_workflows) - len(transient_workflows) + 1)
), f"Expected {len(expected_workflows) + 1} artifacts, found: {len(artifact_files)}"
for artifact in artifact_files:

View File

@ -5,7 +5,7 @@ import networkx as nx
import pytest
from graphrag.config.enums import LLMType
from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage
from graphrag.index.run.utils import create_run_context
from graphrag.index.workflows.v1.create_base_entity_graph import (
build_steps,
workflow_name,
@ -55,7 +55,10 @@ async def test_create_base_entity_graph():
])
expected = load_expected(workflow_name)
storage = MemoryPipelineStorage()
context = create_run_context(None, None, None)
await context.runtime_storage.set(
"base_text_units", input_tables["workflow:create_base_text_units"]
)
config = get_config_for_workflow(workflow_name)
config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_ENTITY_CONFIG
@ -68,7 +71,7 @@ async def test_create_base_entity_graph():
{
"steps": steps,
},
storage=storage,
context=context,
)
assert len(actual.columns) == len(
@ -88,7 +91,7 @@ async def test_create_base_entity_graph():
nodes = list(actual_graph_0.nodes(data=True))
assert nodes[0][1]["description"] == "Company_A is a test company"
assert len(storage.keys()) == 0, "Storage should be empty"
assert len(context.storage.keys()) == 0, "Storage should be empty"
async def test_create_base_entity_graph_with_embeddings():
@ -97,6 +100,11 @@ async def test_create_base_entity_graph_with_embeddings():
])
expected = load_expected(workflow_name)
context = create_run_context(None, None, None)
await context.runtime_storage.set(
"base_text_units", input_tables["workflow:create_base_text_units"]
)
config = get_config_for_workflow(workflow_name)
config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_ENTITY_CONFIG
@ -110,6 +118,7 @@ async def test_create_base_entity_graph_with_embeddings():
{
"steps": steps,
},
context=context,
)
assert (
@ -123,7 +132,10 @@ async def test_create_base_entity_graph_with_snapshots():
"workflow:create_base_text_units",
])
storage = MemoryPipelineStorage()
context = create_run_context(None, None, None)
await context.runtime_storage.set(
"base_text_units", input_tables["workflow:create_base_text_units"]
)
config = get_config_for_workflow(workflow_name)
@ -140,10 +152,10 @@ async def test_create_base_entity_graph_with_snapshots():
{
"steps": steps,
},
storage=storage,
context=context,
)
assert storage.keys() == [
assert context.storage.keys() == [
"raw_extracted_entities.json",
"merged_graph.graphml",
"summarized_graph.graphml",
@ -157,6 +169,11 @@ async def test_create_base_entity_graph_missing_llm_throws():
"workflow:create_base_text_units",
])
context = create_run_context(None, None, None)
await context.runtime_storage.set(
"base_text_units", input_tables["workflow:create_base_text_units"]
)
config = get_config_for_workflow(workflow_name)
config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_ENTITY_CONFIG
@ -170,4 +187,5 @@ async def test_create_base_entity_graph_missing_llm_throws():
{
"steps": steps,
},
context=context,
)

View File

@ -1,6 +1,7 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from graphrag.index.run.utils import create_run_context
from graphrag.index.workflows.v1.create_base_text_units import (
build_steps,
workflow_name,
@ -19,17 +20,21 @@ async def test_create_base_text_units():
input_tables = load_input_tables(inputs=[])
expected = load_expected(workflow_name)
context = create_run_context(None, None, None)
config = get_config_for_workflow(workflow_name)
# test data was created with 4o, so we need to match the encoding for chunks to be identical
config["text_chunk"]["strategy"]["encoding_name"] = "o200k_base"
steps = build_steps(config)
actual = await get_workflow_output(
await get_workflow_output(
input_tables,
{
"steps": steps,
},
context,
)
actual = await context.runtime_storage.get("base_text_units")
compare_outputs(actual, expected)

View File

@ -6,6 +6,7 @@ from datashaper.errors import VerbParallelizationError
from pandas.testing import assert_series_equal
from graphrag.config.enums import LLMType
from graphrag.index.run.utils import create_run_context
from graphrag.index.workflows.v1.create_final_covariates import (
build_steps,
workflow_name,
@ -31,6 +32,11 @@ async def test_create_final_covariates():
input_tables = load_input_tables(["workflow:create_base_text_units"])
expected = load_expected(workflow_name)
context = create_run_context(None, None, None)
await context.runtime_storage.set(
"base_text_units", input_tables["workflow:create_base_text_units"]
)
config = get_config_for_workflow(workflow_name)
config["claim_extract"]["strategy"]["llm"] = MOCK_LLM_CONFIG
@ -42,6 +48,7 @@ async def test_create_final_covariates():
{
"steps": steps,
},
context,
)
input = input_tables["workflow:create_base_text_units"]
@ -81,6 +88,11 @@ async def test_create_final_covariates():
async def test_create_final_covariates_missing_llm_throws():
input_tables = load_input_tables(["workflow:create_base_text_units"])
context = create_run_context(None, None, None)
await context.runtime_storage.set(
"base_text_units", input_tables["workflow:create_base_text_units"]
)
config = get_config_for_workflow(workflow_name)
del config["claim_extract"]["strategy"]["llm"]
@ -93,4 +105,5 @@ async def test_create_final_covariates_missing_llm_throws():
{
"steps": steps,
},
context,
)

View File

@ -1,7 +1,7 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage
from graphrag.index.run.utils import create_run_context
from graphrag.index.workflows.v1.create_final_nodes import (
build_steps,
workflow_name,
@ -22,7 +22,7 @@ async def test_create_final_nodes():
])
expected = load_expected(workflow_name)
storage = MemoryPipelineStorage()
context = create_run_context(None, None, None)
config = get_config_for_workflow(workflow_name)
@ -37,12 +37,12 @@ async def test_create_final_nodes():
{
"steps": steps,
},
storage=storage,
context=context,
)
compare_outputs(actual, expected)
assert len(storage.keys()) == 0, "Storage should be empty"
assert len(context.storage.keys()) == 0, "Storage should be empty"
async def test_create_final_nodes_with_snapshot():
@ -51,7 +51,7 @@ async def test_create_final_nodes_with_snapshot():
])
expected = load_expected(workflow_name)
storage = MemoryPipelineStorage()
context = create_run_context(None, None, None)
config = get_config_for_workflow(workflow_name)
@ -67,11 +67,11 @@ async def test_create_final_nodes_with_snapshot():
{
"steps": steps,
},
storage=storage,
context=context,
)
assert actual.shape == expected.shape, "Graph dataframe shapes differ"
assert storage.keys() == [
assert context.storage.keys() == [
"top_level_nodes.json",
], "Graph snapshot keys differ"

View File

@ -1,6 +1,7 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from graphrag.index.run.utils import create_run_context
from graphrag.index.workflows.v1.create_final_text_units import (
build_steps,
workflow_name,
@ -24,6 +25,11 @@ async def test_create_final_text_units():
])
expected = load_expected(workflow_name)
context = create_run_context(None, None, None)
await context.runtime_storage.set(
"base_text_units", input_tables["workflow:create_base_text_units"]
)
config = get_config_for_workflow(workflow_name)
config["covariates_enabled"] = True
@ -36,6 +42,7 @@ async def test_create_final_text_units():
{
"steps": steps,
},
context=context,
)
compare_outputs(actual, expected)
@ -50,6 +57,11 @@ async def test_create_final_text_units_no_covariates():
])
expected = load_expected(workflow_name)
context = create_run_context(None, None, None)
await context.runtime_storage.set(
"base_text_units", input_tables["workflow:create_base_text_units"]
)
config = get_config_for_workflow(workflow_name)
config["covariates_enabled"] = False
@ -62,6 +74,7 @@ async def test_create_final_text_units_no_covariates():
{
"steps": steps,
},
context=context,
)
# we're short a covariate_ids column
@ -81,6 +94,11 @@ async def test_create_final_text_units_with_embeddings():
])
expected = load_expected(workflow_name)
context = create_run_context(None, None, None)
await context.runtime_storage.set(
"base_text_units", input_tables["workflow:create_base_text_units"]
)
config = get_config_for_workflow(workflow_name)
config["covariates_enabled"] = True
@ -96,6 +114,7 @@ async def test_create_final_text_units_with_embeddings():
{
"steps": steps,
},
context=context,
)
assert "text_embedding" in actual.columns

View File

@ -12,8 +12,8 @@ from graphrag.index import (
PipelineWorkflowConfig,
create_pipeline_config,
)
from graphrag.index.run.utils import _create_run_context
from graphrag.index.storage.pipeline_storage import PipelineStorage
from graphrag.index.context import PipelineRunContext
from graphrag.index.run.utils import create_run_context
pd.set_option("display.max_columns", None)
@ -60,7 +60,7 @@ def get_config_for_workflow(name: str) -> PipelineWorkflowConfig:
async def get_workflow_output(
input_tables: dict[str, pd.DataFrame],
schema: dict,
storage: PipelineStorage | None = None,
context: PipelineRunContext | None = None,
) -> pd.DataFrame:
"""Pass in the input tables, the schema, and the output name"""
@ -70,9 +70,9 @@ async def get_workflow_output(
input_tables=input_tables,
)
context = _create_run_context(storage, None, None)
run_context = context or create_run_context(None, None, None)
await workflow.run(context=context)
await workflow.run(context=run_context)
# if there's only one output, it is the default here, no name required
return cast(pd.DataFrame, workflow.output())