mirror of
https://github.com/microsoft/graphrag.git
synced 2025-11-04 19:59:54 +00:00
Rework workflow architecture (#1311)
* Rename pipeline_storage file * Add runtime storage option to context * Fix import * Switch to memory storage for runtime * Infra for workflow runtime storage * Migrate base_text_units to runtime storage * Fix comment * Semver * Remove whitespace * Remove subflow smoke tests and ignore transient artifacts * Remove entity graph from transient list (not yet implemented) * Increase smoke runtime allotment for create_base_entity_graph * Revert format fix * Remove noqa
This commit is contained in:
parent
ac09e0a740
commit
94f1e62e5c
@ -0,0 +1,4 @@
|
||||
{
|
||||
"type": "patch",
|
||||
"description": "Add runtime-only storage option."
|
||||
}
|
||||
@ -34,7 +34,11 @@ class PipelineRunContext:
|
||||
|
||||
stats: PipelineRunStats
|
||||
storage: PipelineStorage
|
||||
"Long-term storage for pipeline verbs to use. Items written here will be written to the storage provider."
|
||||
cache: PipelineCache
|
||||
"Cache instance for reading previous LLM responses."
|
||||
runtime_storage: PipelineStorage
|
||||
"Runtime only storage for pipeline verbs to use. Items written here will only live in memory during the current run."
|
||||
|
||||
|
||||
# TODO: For now, just has the same props available to it
|
||||
|
||||
@ -69,7 +69,3 @@ async def _write_workflow_stats(
|
||||
await _save_profiler_stats(
|
||||
storage, workflow.name, workflow_result.memory_profile
|
||||
)
|
||||
|
||||
log.debug(
|
||||
"first row of %s => %s", workflow.name, workflow.output().iloc[0].to_json()
|
||||
)
|
||||
|
||||
@ -32,8 +32,8 @@ from graphrag.index.run.utils import (
|
||||
_apply_substitutions,
|
||||
_create_input,
|
||||
_create_reporter,
|
||||
_create_run_context,
|
||||
_validate_dataset,
|
||||
create_run_context,
|
||||
)
|
||||
from graphrag.index.run.workflow import (
|
||||
_create_callback_chain,
|
||||
@ -200,7 +200,7 @@ async def run_pipeline(
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
context = _create_run_context(storage=storage, cache=cache, stats=None)
|
||||
context = create_run_context(storage=storage, cache=cache, stats=None)
|
||||
|
||||
progress_reporter = progress_reporter or NullProgressReporter()
|
||||
callbacks = callbacks or ConsoleWorkflowCallbacks()
|
||||
|
||||
@ -103,7 +103,7 @@ def _apply_substitutions(config: PipelineConfig, run_id: str) -> PipelineConfig:
|
||||
return config
|
||||
|
||||
|
||||
def _create_run_context(
|
||||
def create_run_context(
|
||||
storage: PipelineStorage | None,
|
||||
cache: PipelineCache | None,
|
||||
stats: PipelineRunStats | None,
|
||||
@ -113,4 +113,5 @@ def _create_run_context(
|
||||
stats=stats or PipelineRunStats(),
|
||||
cache=cache or InMemoryCache(),
|
||||
storage=storage or MemoryPipelineStorage(),
|
||||
runtime_storage=MemoryPipelineStorage(),
|
||||
)
|
||||
|
||||
@ -39,7 +39,18 @@ async def _inject_workflow_data_dependencies(
|
||||
log.info("dependencies for %s: %s", workflow.name, deps)
|
||||
for id in deps:
|
||||
workflow_id = f"workflow:{id}"
|
||||
try:
|
||||
table = await _load_table_from_storage(f"{id}.parquet", storage)
|
||||
except ValueError:
|
||||
# our workflows now allow transient tables, and we avoid putting those in primary storage
|
||||
# however, we need to keep the table in the dependency list for proper execution order
|
||||
# this allows us to catch missing table errors but emit a warning for pipeline users who may genuinely have an error (which we expect to be very rare)
|
||||
# todo: this issue will resolve itself if we remove DataShaper completely
|
||||
log.warning(
|
||||
"Dependency table %s not found in storage: it may be a runtime-only in-memory table. If you see further errors, this may be an actual problem.",
|
||||
id,
|
||||
)
|
||||
table = pd.DataFrame()
|
||||
workflow.add_table(workflow_id, table)
|
||||
|
||||
|
||||
|
||||
@ -276,6 +276,11 @@ class BlobPipelineStorage(PipelineStorage):
|
||||
self._storage_account_blob_url,
|
||||
)
|
||||
|
||||
def keys(self) -> list[str]:
|
||||
"""Return the keys in the storage."""
|
||||
msg = "Blob storage does yet not support listing keys."
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
def _keyname(self, key: str) -> str:
|
||||
"""Get the key name."""
|
||||
return str(Path(self._path_prefix) / key)
|
||||
|
||||
@ -143,6 +143,10 @@ class FilePipelineStorage(PipelineStorage):
|
||||
return self
|
||||
return FilePipelineStorage(str(Path(self._root_dir) / Path(name)))
|
||||
|
||||
def keys(self) -> list[str]:
|
||||
"""Return the keys in the storage."""
|
||||
return os.listdir(self._root_dir)
|
||||
|
||||
|
||||
def join_path(file_path: str, file_name: str) -> Path:
|
||||
"""Join a path and a file. Independent of the OS."""
|
||||
|
||||
@ -32,7 +32,7 @@ class MemoryPipelineStorage(FilePipelineStorage):
|
||||
-------
|
||||
- output - The value for the given key.
|
||||
"""
|
||||
return self._storage.get(key) or await super().get(key, as_bytes, encoding)
|
||||
return self._storage.get(key)
|
||||
|
||||
async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
|
||||
"""Set the value for the given key.
|
||||
@ -53,7 +53,7 @@ class MemoryPipelineStorage(FilePipelineStorage):
|
||||
-------
|
||||
- output - True if the key exists in the storage, False otherwise.
|
||||
"""
|
||||
return key in self._storage or await super().has(key)
|
||||
return key in self._storage
|
||||
|
||||
async def delete(self, key: str) -> None:
|
||||
"""Delete the given key from the storage.
|
||||
@ -69,7 +69,7 @@ class MemoryPipelineStorage(FilePipelineStorage):
|
||||
|
||||
def child(self, name: str | None) -> "PipelineStorage":
|
||||
"""Create a child storage instance."""
|
||||
return self
|
||||
return MemoryPipelineStorage()
|
||||
|
||||
def keys(self) -> list[str]:
|
||||
"""Return the keys in the storage."""
|
||||
|
||||
@ -76,3 +76,7 @@ class PipelineStorage(metaclass=ABCMeta):
|
||||
@abstractmethod
|
||||
def child(self, name: str | None) -> "PipelineStorage":
|
||||
"""Create a child storage instance."""
|
||||
|
||||
@abstractmethod
|
||||
def keys(self) -> list[str]:
|
||||
"""List all keys in the storage."""
|
||||
|
||||
@ -5,12 +5,10 @@
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
import pandas as pd
|
||||
from datashaper import (
|
||||
AsyncType,
|
||||
Table,
|
||||
VerbCallbacks,
|
||||
VerbInput,
|
||||
verb,
|
||||
)
|
||||
from datashaper.table_store.types import VerbResult, create_verb_result
|
||||
@ -27,10 +25,10 @@ from graphrag.index.storage import PipelineStorage
|
||||
treats_input_tables_as_immutable=True,
|
||||
)
|
||||
async def create_base_entity_graph(
|
||||
input: VerbInput,
|
||||
callbacks: VerbCallbacks,
|
||||
cache: PipelineCache,
|
||||
storage: PipelineStorage,
|
||||
runtime_storage: PipelineStorage,
|
||||
text_column: str,
|
||||
id_column: str,
|
||||
clustering_strategy: dict[str, Any],
|
||||
@ -48,10 +46,10 @@ async def create_base_entity_graph(
|
||||
**_kwargs: dict,
|
||||
) -> VerbResult:
|
||||
"""All the steps to create the base entity graph."""
|
||||
source = cast(pd.DataFrame, input.get_input())
|
||||
text_units = await runtime_storage.get("base_text_units")
|
||||
|
||||
output = await create_base_entity_graph_flow(
|
||||
source,
|
||||
text_units,
|
||||
callbacks,
|
||||
cache,
|
||||
storage,
|
||||
|
||||
@ -17,12 +17,14 @@ from datashaper.table_store.types import VerbResult, create_verb_result
|
||||
from graphrag.index.flows.create_base_text_units import (
|
||||
create_base_text_units as create_base_text_units_flow,
|
||||
)
|
||||
from graphrag.index.storage import PipelineStorage
|
||||
|
||||
|
||||
@verb(name="create_base_text_units", treats_input_tables_as_immutable=True)
|
||||
def create_base_text_units(
|
||||
async def create_base_text_units(
|
||||
input: VerbInput,
|
||||
callbacks: VerbCallbacks,
|
||||
runtime_storage: PipelineStorage,
|
||||
chunk_column_name: str,
|
||||
n_tokens_column_name: str,
|
||||
chunk_by_columns: list[str],
|
||||
@ -41,9 +43,11 @@ def create_base_text_units(
|
||||
chunk_strategy=chunk_strategy,
|
||||
)
|
||||
|
||||
await runtime_storage.set("base_text_units", output)
|
||||
|
||||
return create_verb_result(
|
||||
cast(
|
||||
Table,
|
||||
output,
|
||||
pd.DataFrame(),
|
||||
)
|
||||
)
|
||||
|
||||
@ -5,12 +5,10 @@
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
import pandas as pd
|
||||
from datashaper import (
|
||||
AsyncType,
|
||||
Table,
|
||||
VerbCallbacks,
|
||||
VerbInput,
|
||||
verb,
|
||||
)
|
||||
from datashaper.table_store.types import VerbResult, create_verb_result
|
||||
@ -19,13 +17,14 @@ from graphrag.index.cache import PipelineCache
|
||||
from graphrag.index.flows.create_final_covariates import (
|
||||
create_final_covariates as create_final_covariates_flow,
|
||||
)
|
||||
from graphrag.index.storage import PipelineStorage
|
||||
|
||||
|
||||
@verb(name="create_final_covariates", treats_input_tables_as_immutable=True)
|
||||
async def create_final_covariates(
|
||||
input: VerbInput,
|
||||
callbacks: VerbCallbacks,
|
||||
cache: PipelineCache,
|
||||
runtime_storage: PipelineStorage,
|
||||
column: str,
|
||||
covariate_type: str,
|
||||
extraction_strategy: dict[str, Any] | None,
|
||||
@ -35,10 +34,10 @@ async def create_final_covariates(
|
||||
**_kwargs: dict,
|
||||
) -> VerbResult:
|
||||
"""All the steps to extract and format covariates."""
|
||||
source = cast(pd.DataFrame, input.get_input())
|
||||
text_units = await runtime_storage.get("base_text_units")
|
||||
|
||||
output = await create_final_covariates_flow(
|
||||
source,
|
||||
text_units,
|
||||
callbacks,
|
||||
cache,
|
||||
column,
|
||||
|
||||
@ -19,6 +19,7 @@ from graphrag.index.cache import PipelineCache
|
||||
from graphrag.index.flows.create_final_text_units import (
|
||||
create_final_text_units as create_final_text_units_flow,
|
||||
)
|
||||
from graphrag.index.storage import PipelineStorage
|
||||
from graphrag.index.utils.ds_util import get_named_input_table, get_required_input_table
|
||||
|
||||
|
||||
@ -27,11 +28,12 @@ async def create_final_text_units(
|
||||
input: VerbInput,
|
||||
callbacks: VerbCallbacks,
|
||||
cache: PipelineCache,
|
||||
runtime_storage: PipelineStorage,
|
||||
text_text_embed: dict | None = None,
|
||||
**_kwargs: dict,
|
||||
) -> VerbResult:
|
||||
"""All the steps to transform the text units."""
|
||||
source = cast(pd.DataFrame, input.get_input())
|
||||
text_units = await runtime_storage.get("base_text_units")
|
||||
final_entities = cast(
|
||||
pd.DataFrame, get_required_input_table(input, "entities").table
|
||||
)
|
||||
@ -44,7 +46,7 @@ async def create_final_text_units(
|
||||
final_covariates = cast(pd.DataFrame, final_covariates.table)
|
||||
|
||||
output = await create_final_text_units_flow(
|
||||
source,
|
||||
text_units,
|
||||
final_entities,
|
||||
final_relationships,
|
||||
final_covariates,
|
||||
|
||||
2
tests/fixtures/min-csv/config.json
vendored
2
tests/fixtures/min-csv/config.json
vendored
@ -16,7 +16,7 @@
|
||||
2500
|
||||
],
|
||||
"subworkflows": 1,
|
||||
"max_runtime": 100
|
||||
"max_runtime": 300
|
||||
},
|
||||
"create_final_entities": {
|
||||
"row_range": [
|
||||
|
||||
@ -178,19 +178,8 @@ class TestIndexer:
|
||||
workflows == expected_workflows
|
||||
), f"Workflows missing from stats.json: {expected_workflows - workflows}. Unexpected workflows in stats.json: {workflows - expected_workflows}"
|
||||
|
||||
# [OPTIONAL] Check subworkflows
|
||||
# [OPTIONAL] Check runtime
|
||||
for workflow in expected_workflows:
|
||||
if "subworkflows" in workflow_config[workflow]:
|
||||
# Check number of subworkflows
|
||||
subworkflows = stats["workflows"][workflow]
|
||||
expected_subworkflows = workflow_config[workflow].get(
|
||||
"subworkflows", None
|
||||
)
|
||||
if expected_subworkflows:
|
||||
assert (
|
||||
len(subworkflows) - 1 == expected_subworkflows
|
||||
), f"Expected {expected_subworkflows} subworkflows, found: {len(subworkflows) - 1} for workflow: {workflow}: [{subworkflows}]"
|
||||
|
||||
# Check max runtime
|
||||
max_runtime = workflow_config[workflow].get("max_runtime", None)
|
||||
if max_runtime:
|
||||
@ -200,8 +189,15 @@ class TestIndexer:
|
||||
|
||||
# Check artifacts
|
||||
artifact_files = os.listdir(artifacts)
|
||||
# check that the number of workflows matches the number of artifacts, but:
|
||||
# (1) do not count workflows with only transient output
|
||||
# (2) account for the stats.json file
|
||||
transient_workflows = [
|
||||
"workflow:create_base_text_units",
|
||||
]
|
||||
assert (
|
||||
len(artifact_files) == len(expected_workflows) + 1
|
||||
len(artifact_files)
|
||||
== (len(expected_workflows) - len(transient_workflows) + 1)
|
||||
), f"Expected {len(expected_workflows) + 1} artifacts, found: {len(artifact_files)}"
|
||||
|
||||
for artifact in artifact_files:
|
||||
|
||||
@ -5,7 +5,7 @@ import networkx as nx
|
||||
import pytest
|
||||
|
||||
from graphrag.config.enums import LLMType
|
||||
from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage
|
||||
from graphrag.index.run.utils import create_run_context
|
||||
from graphrag.index.workflows.v1.create_base_entity_graph import (
|
||||
build_steps,
|
||||
workflow_name,
|
||||
@ -55,7 +55,10 @@ async def test_create_base_entity_graph():
|
||||
])
|
||||
expected = load_expected(workflow_name)
|
||||
|
||||
storage = MemoryPipelineStorage()
|
||||
context = create_run_context(None, None, None)
|
||||
await context.runtime_storage.set(
|
||||
"base_text_units", input_tables["workflow:create_base_text_units"]
|
||||
)
|
||||
|
||||
config = get_config_for_workflow(workflow_name)
|
||||
config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_ENTITY_CONFIG
|
||||
@ -68,7 +71,7 @@ async def test_create_base_entity_graph():
|
||||
{
|
||||
"steps": steps,
|
||||
},
|
||||
storage=storage,
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert len(actual.columns) == len(
|
||||
@ -88,7 +91,7 @@ async def test_create_base_entity_graph():
|
||||
nodes = list(actual_graph_0.nodes(data=True))
|
||||
assert nodes[0][1]["description"] == "Company_A is a test company"
|
||||
|
||||
assert len(storage.keys()) == 0, "Storage should be empty"
|
||||
assert len(context.storage.keys()) == 0, "Storage should be empty"
|
||||
|
||||
|
||||
async def test_create_base_entity_graph_with_embeddings():
|
||||
@ -97,6 +100,11 @@ async def test_create_base_entity_graph_with_embeddings():
|
||||
])
|
||||
expected = load_expected(workflow_name)
|
||||
|
||||
context = create_run_context(None, None, None)
|
||||
await context.runtime_storage.set(
|
||||
"base_text_units", input_tables["workflow:create_base_text_units"]
|
||||
)
|
||||
|
||||
config = get_config_for_workflow(workflow_name)
|
||||
|
||||
config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_ENTITY_CONFIG
|
||||
@ -110,6 +118,7 @@ async def test_create_base_entity_graph_with_embeddings():
|
||||
{
|
||||
"steps": steps,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert (
|
||||
@ -123,7 +132,10 @@ async def test_create_base_entity_graph_with_snapshots():
|
||||
"workflow:create_base_text_units",
|
||||
])
|
||||
|
||||
storage = MemoryPipelineStorage()
|
||||
context = create_run_context(None, None, None)
|
||||
await context.runtime_storage.set(
|
||||
"base_text_units", input_tables["workflow:create_base_text_units"]
|
||||
)
|
||||
|
||||
config = get_config_for_workflow(workflow_name)
|
||||
|
||||
@ -140,10 +152,10 @@ async def test_create_base_entity_graph_with_snapshots():
|
||||
{
|
||||
"steps": steps,
|
||||
},
|
||||
storage=storage,
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert storage.keys() == [
|
||||
assert context.storage.keys() == [
|
||||
"raw_extracted_entities.json",
|
||||
"merged_graph.graphml",
|
||||
"summarized_graph.graphml",
|
||||
@ -157,6 +169,11 @@ async def test_create_base_entity_graph_missing_llm_throws():
|
||||
"workflow:create_base_text_units",
|
||||
])
|
||||
|
||||
context = create_run_context(None, None, None)
|
||||
await context.runtime_storage.set(
|
||||
"base_text_units", input_tables["workflow:create_base_text_units"]
|
||||
)
|
||||
|
||||
config = get_config_for_workflow(workflow_name)
|
||||
|
||||
config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_ENTITY_CONFIG
|
||||
@ -170,4 +187,5 @@ async def test_create_base_entity_graph_missing_llm_throws():
|
||||
{
|
||||
"steps": steps,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
from graphrag.index.run.utils import create_run_context
|
||||
from graphrag.index.workflows.v1.create_base_text_units import (
|
||||
build_steps,
|
||||
workflow_name,
|
||||
@ -19,17 +20,21 @@ async def test_create_base_text_units():
|
||||
input_tables = load_input_tables(inputs=[])
|
||||
expected = load_expected(workflow_name)
|
||||
|
||||
context = create_run_context(None, None, None)
|
||||
|
||||
config = get_config_for_workflow(workflow_name)
|
||||
# test data was created with 4o, so we need to match the encoding for chunks to be identical
|
||||
config["text_chunk"]["strategy"]["encoding_name"] = "o200k_base"
|
||||
|
||||
steps = build_steps(config)
|
||||
|
||||
actual = await get_workflow_output(
|
||||
await get_workflow_output(
|
||||
input_tables,
|
||||
{
|
||||
"steps": steps,
|
||||
},
|
||||
context,
|
||||
)
|
||||
|
||||
actual = await context.runtime_storage.get("base_text_units")
|
||||
compare_outputs(actual, expected)
|
||||
|
||||
@ -6,6 +6,7 @@ from datashaper.errors import VerbParallelizationError
|
||||
from pandas.testing import assert_series_equal
|
||||
|
||||
from graphrag.config.enums import LLMType
|
||||
from graphrag.index.run.utils import create_run_context
|
||||
from graphrag.index.workflows.v1.create_final_covariates import (
|
||||
build_steps,
|
||||
workflow_name,
|
||||
@ -31,6 +32,11 @@ async def test_create_final_covariates():
|
||||
input_tables = load_input_tables(["workflow:create_base_text_units"])
|
||||
expected = load_expected(workflow_name)
|
||||
|
||||
context = create_run_context(None, None, None)
|
||||
await context.runtime_storage.set(
|
||||
"base_text_units", input_tables["workflow:create_base_text_units"]
|
||||
)
|
||||
|
||||
config = get_config_for_workflow(workflow_name)
|
||||
|
||||
config["claim_extract"]["strategy"]["llm"] = MOCK_LLM_CONFIG
|
||||
@ -42,6 +48,7 @@ async def test_create_final_covariates():
|
||||
{
|
||||
"steps": steps,
|
||||
},
|
||||
context,
|
||||
)
|
||||
|
||||
input = input_tables["workflow:create_base_text_units"]
|
||||
@ -81,6 +88,11 @@ async def test_create_final_covariates():
|
||||
async def test_create_final_covariates_missing_llm_throws():
|
||||
input_tables = load_input_tables(["workflow:create_base_text_units"])
|
||||
|
||||
context = create_run_context(None, None, None)
|
||||
await context.runtime_storage.set(
|
||||
"base_text_units", input_tables["workflow:create_base_text_units"]
|
||||
)
|
||||
|
||||
config = get_config_for_workflow(workflow_name)
|
||||
|
||||
del config["claim_extract"]["strategy"]["llm"]
|
||||
@ -93,4 +105,5 @@ async def test_create_final_covariates_missing_llm_throws():
|
||||
{
|
||||
"steps": steps,
|
||||
},
|
||||
context,
|
||||
)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage
|
||||
from graphrag.index.run.utils import create_run_context
|
||||
from graphrag.index.workflows.v1.create_final_nodes import (
|
||||
build_steps,
|
||||
workflow_name,
|
||||
@ -22,7 +22,7 @@ async def test_create_final_nodes():
|
||||
])
|
||||
expected = load_expected(workflow_name)
|
||||
|
||||
storage = MemoryPipelineStorage()
|
||||
context = create_run_context(None, None, None)
|
||||
|
||||
config = get_config_for_workflow(workflow_name)
|
||||
|
||||
@ -37,12 +37,12 @@ async def test_create_final_nodes():
|
||||
{
|
||||
"steps": steps,
|
||||
},
|
||||
storage=storage,
|
||||
context=context,
|
||||
)
|
||||
|
||||
compare_outputs(actual, expected)
|
||||
|
||||
assert len(storage.keys()) == 0, "Storage should be empty"
|
||||
assert len(context.storage.keys()) == 0, "Storage should be empty"
|
||||
|
||||
|
||||
async def test_create_final_nodes_with_snapshot():
|
||||
@ -51,7 +51,7 @@ async def test_create_final_nodes_with_snapshot():
|
||||
])
|
||||
expected = load_expected(workflow_name)
|
||||
|
||||
storage = MemoryPipelineStorage()
|
||||
context = create_run_context(None, None, None)
|
||||
|
||||
config = get_config_for_workflow(workflow_name)
|
||||
|
||||
@ -67,11 +67,11 @@ async def test_create_final_nodes_with_snapshot():
|
||||
{
|
||||
"steps": steps,
|
||||
},
|
||||
storage=storage,
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert actual.shape == expected.shape, "Graph dataframe shapes differ"
|
||||
|
||||
assert storage.keys() == [
|
||||
assert context.storage.keys() == [
|
||||
"top_level_nodes.json",
|
||||
], "Graph snapshot keys differ"
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
from graphrag.index.run.utils import create_run_context
|
||||
from graphrag.index.workflows.v1.create_final_text_units import (
|
||||
build_steps,
|
||||
workflow_name,
|
||||
@ -24,6 +25,11 @@ async def test_create_final_text_units():
|
||||
])
|
||||
expected = load_expected(workflow_name)
|
||||
|
||||
context = create_run_context(None, None, None)
|
||||
await context.runtime_storage.set(
|
||||
"base_text_units", input_tables["workflow:create_base_text_units"]
|
||||
)
|
||||
|
||||
config = get_config_for_workflow(workflow_name)
|
||||
|
||||
config["covariates_enabled"] = True
|
||||
@ -36,6 +42,7 @@ async def test_create_final_text_units():
|
||||
{
|
||||
"steps": steps,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
compare_outputs(actual, expected)
|
||||
@ -50,6 +57,11 @@ async def test_create_final_text_units_no_covariates():
|
||||
])
|
||||
expected = load_expected(workflow_name)
|
||||
|
||||
context = create_run_context(None, None, None)
|
||||
await context.runtime_storage.set(
|
||||
"base_text_units", input_tables["workflow:create_base_text_units"]
|
||||
)
|
||||
|
||||
config = get_config_for_workflow(workflow_name)
|
||||
|
||||
config["covariates_enabled"] = False
|
||||
@ -62,6 +74,7 @@ async def test_create_final_text_units_no_covariates():
|
||||
{
|
||||
"steps": steps,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
# we're short a covariate_ids column
|
||||
@ -81,6 +94,11 @@ async def test_create_final_text_units_with_embeddings():
|
||||
])
|
||||
expected = load_expected(workflow_name)
|
||||
|
||||
context = create_run_context(None, None, None)
|
||||
await context.runtime_storage.set(
|
||||
"base_text_units", input_tables["workflow:create_base_text_units"]
|
||||
)
|
||||
|
||||
config = get_config_for_workflow(workflow_name)
|
||||
|
||||
config["covariates_enabled"] = True
|
||||
@ -96,6 +114,7 @@ async def test_create_final_text_units_with_embeddings():
|
||||
{
|
||||
"steps": steps,
|
||||
},
|
||||
context=context,
|
||||
)
|
||||
|
||||
assert "text_embedding" in actual.columns
|
||||
|
||||
@ -12,8 +12,8 @@ from graphrag.index import (
|
||||
PipelineWorkflowConfig,
|
||||
create_pipeline_config,
|
||||
)
|
||||
from graphrag.index.run.utils import _create_run_context
|
||||
from graphrag.index.storage.pipeline_storage import PipelineStorage
|
||||
from graphrag.index.context import PipelineRunContext
|
||||
from graphrag.index.run.utils import create_run_context
|
||||
|
||||
pd.set_option("display.max_columns", None)
|
||||
|
||||
@ -60,7 +60,7 @@ def get_config_for_workflow(name: str) -> PipelineWorkflowConfig:
|
||||
async def get_workflow_output(
|
||||
input_tables: dict[str, pd.DataFrame],
|
||||
schema: dict,
|
||||
storage: PipelineStorage | None = None,
|
||||
context: PipelineRunContext | None = None,
|
||||
) -> pd.DataFrame:
|
||||
"""Pass in the input tables, the schema, and the output name"""
|
||||
|
||||
@ -70,9 +70,9 @@ async def get_workflow_output(
|
||||
input_tables=input_tables,
|
||||
)
|
||||
|
||||
context = _create_run_context(storage, None, None)
|
||||
run_context = context or create_run_context(None, None, None)
|
||||
|
||||
await workflow.run(context=context)
|
||||
await workflow.run(context=run_context)
|
||||
|
||||
# if there's only one output, it is the default here, no name required
|
||||
return cast(pd.DataFrame, workflow.output())
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user