mirror of https://github.com/microsoft/graphrag.git (synced 2025-11-05 20:34:21 +00:00)

Rework workflow architecture (#1311)

* Rename pipeline_storage file
* Add runtime storage option to context
* Fix import
* Switch to memory storage for runtime
* Infra for workflow runtime storage
* Migrate base_text_units to runtime storage
* Fix comment
* Semver
* Remove whitespace
* Remove subflow smoke tests and ignore transient artifacts
* Remove entity graph from transient list (not yet implemented)
* Increase smoke runtime allotment for create_base_entity_graph
* Revert format fix
* Remove noqa

This commit is contained in: parent ac09e0a740, commit 94f1e62e5c
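
For orientation, the core pattern this change introduces is a runtime-only storage channel on the pipeline context: upstream verbs stash transient tables there instead of emitting them as persisted workflow artifacts, and downstream verbs read them back. A minimal, self-contained sketch of that handoff (the store and verb names here are illustrative stand-ins, not graphrag's API):

import asyncio
from dataclasses import dataclass, field
from typing import Any


@dataclass
class InMemoryStore:
    """Stand-in for an in-memory pipeline storage keyed by table name."""

    _items: dict[str, Any] = field(default_factory=dict)

    async def set(self, key: str, value: Any) -> None:
        self._items[key] = value

    async def get(self, key: str) -> Any:
        return self._items.get(key)


async def produce_text_units(runtime_storage: InMemoryStore) -> None:
    # An upstream verb writes its intermediate table to runtime-only storage...
    await runtime_storage.set("base_text_units", [{"id": 1, "text": "chunk"}])


async def consume_text_units(runtime_storage: InMemoryStore) -> Any:
    # ...and a downstream verb reads it back without touching long-term storage.
    return await runtime_storage.get("base_text_units")


async def main() -> None:
    store = InMemoryStore()
    await produce_text_units(store)
    assert await consume_text_units(store) == [{"id": 1, "text": "chunk"}]


asyncio.run(main())

The diff below wires this idea into PipelineRunContext, the storage implementations, the create_base_text_units producer, and its consumers.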
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Add runtime-only storage option."
+}

@@ -34,7 +34,11 @@ class PipelineRunContext:

     stats: PipelineRunStats
     storage: PipelineStorage
+    "Long-term storage for pipeline verbs to use. Items written here will be written to the storage provider."
     cache: PipelineCache
+    "Cache instance for reading previous LLM responses."
+    runtime_storage: PipelineStorage
+    "Runtime only storage for pipeline verbs to use. Items written here will only live in memory during the current run."


 # TODO: For now, just has the same props available to it

@@ -69,7 +69,3 @@ async def _write_workflow_stats(
         await _save_profiler_stats(
             storage, workflow.name, workflow_result.memory_profile
         )
-
-    log.debug(
-        "first row of %s => %s", workflow.name, workflow.output().iloc[0].to_json()
-    )

@@ -32,8 +32,8 @@ from graphrag.index.run.utils import (
     _apply_substitutions,
     _create_input,
     _create_reporter,
-    _create_run_context,
     _validate_dataset,
+    create_run_context,
 )
 from graphrag.index.run.workflow import (
     _create_callback_chain,

@@ -200,7 +200,7 @@ async def run_pipeline(
     """
     start_time = time.time()

-    context = _create_run_context(storage=storage, cache=cache, stats=None)
+    context = create_run_context(storage=storage, cache=cache, stats=None)

     progress_reporter = progress_reporter or NullProgressReporter()
     callbacks = callbacks or ConsoleWorkflowCallbacks()

@@ -103,7 +103,7 @@ def _apply_substitutions(config: PipelineConfig, run_id: str) -> PipelineConfig:
     return config


-def _create_run_context(
+def create_run_context(
     storage: PipelineStorage | None,
     cache: PipelineCache | None,
     stats: PipelineRunStats | None,

@@ -113,4 +113,5 @@ def _create_run_context(
         stats=stats or PipelineRunStats(),
         cache=cache or InMemoryCache(),
         storage=storage or MemoryPipelineStorage(),
+        runtime_storage=MemoryPipelineStorage(),
     )
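
A minimal usage sketch of the renamed helper (assuming the defaults shown in the hunk above): passing None for every argument yields in-memory storage, cache, and runtime storage, which is exactly how the workflow unit tests later in this diff construct their contexts.

from graphrag.index.run.utils import create_run_context

context = create_run_context(None, None, None)
# Long-term artifacts go to context.storage; transient tables go to context.runtime_storage.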
@@ -39,7 +39,18 @@ async def _inject_workflow_data_dependencies(
     log.info("dependencies for %s: %s", workflow.name, deps)
     for id in deps:
         workflow_id = f"workflow:{id}"
+        try:
             table = await _load_table_from_storage(f"{id}.parquet", storage)
+        except ValueError:
+            # our workflows now allow transient tables, and we avoid putting those in primary storage
+            # however, we need to keep the table in the dependency list for proper execution order
+            # this allows us to catch missing table errors but emit a warning for pipeline users who may genuinely have an error (which we expect to be very rare)
+            # todo: this issue will resolve itself if we remove DataShaper completely
+            log.warning(
+                "Dependency table %s not found in storage: it may be a runtime-only in-memory table. If you see further errors, this may be an actual problem.",
+                id,
+            )
+            table = pd.DataFrame()
         workflow.add_table(workflow_id, table)

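
The fallback above boils down to a small, reusable pattern: attempt to load a dependency from primary storage and substitute an empty frame (with a warning) when the table is runtime-only. A self-contained sketch using a hypothetical loader callable rather than graphrag's internal _load_table_from_storage:

import logging
from collections.abc import Awaitable, Callable

import pandas as pd

log = logging.getLogger(__name__)


async def load_or_empty(
    name: str,
    loader: Callable[[str], Awaitable[pd.DataFrame]],
) -> pd.DataFrame:
    """Load a dependency table, falling back to an empty frame for runtime-only tables."""
    try:
        return await loader(f"{name}.parquet")
    except ValueError:
        # Transient tables never reach primary storage; keep the slot so execution order holds.
        log.warning("Dependency table %s not in primary storage; it may be runtime-only.", name)
        return pd.DataFrame()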
@@ -276,6 +276,11 @@ class BlobPipelineStorage(PipelineStorage):
             self._storage_account_blob_url,
         )

+    def keys(self) -> list[str]:
+        """Return the keys in the storage."""
+        msg = "Blob storage does yet not support listing keys."
+        raise NotImplementedError(msg)
+
     def _keyname(self, key: str) -> str:
         """Get the key name."""
         return str(Path(self._path_prefix) / key)

@@ -143,6 +143,10 @@ class FilePipelineStorage(PipelineStorage):
             return self
         return FilePipelineStorage(str(Path(self._root_dir) / Path(name)))

+    def keys(self) -> list[str]:
+        """Return the keys in the storage."""
+        return os.listdir(self._root_dir)
+

 def join_path(file_path: str, file_name: str) -> Path:
     """Join a path and a file. Independent of the OS."""

@@ -32,7 +32,7 @@ class MemoryPipelineStorage(FilePipelineStorage):
         -------
             - output - The value for the given key.
         """
-        return self._storage.get(key) or await super().get(key, as_bytes, encoding)
+        return self._storage.get(key)

     async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
         """Set the value for the given key.

@@ -53,7 +53,7 @@ class MemoryPipelineStorage(FilePipelineStorage):
         -------
             - output - True if the key exists in the storage, False otherwise.
         """
-        return key in self._storage or await super().has(key)
+        return key in self._storage

     async def delete(self, key: str) -> None:
         """Delete the given key from the storage.

@@ -69,7 +69,7 @@ class MemoryPipelineStorage(FilePipelineStorage):

     def child(self, name: str | None) -> "PipelineStorage":
         """Create a child storage instance."""
-        return self
+        return MemoryPipelineStorage()

     def keys(self) -> list[str]:
         """Return the keys in the storage."""
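
A small behavioral consequence of the MemoryPipelineStorage changes above: get and has no longer fall back to the file-backed parent class, so a memory store only ever sees what was written to it during the current run. A quick sketch, assuming the class as modified above:

import asyncio

from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage


async def main() -> None:
    store = MemoryPipelineStorage()
    await store.set("k", "v")
    assert await store.has("k")          # served from the in-memory dict
    assert not await store.has("other")  # no longer falls back to files on disk


asyncio.run(main())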
@@ -76,3 +76,7 @@ class PipelineStorage(metaclass=ABCMeta):
     @abstractmethod
     def child(self, name: str | None) -> "PipelineStorage":
         """Create a child storage instance."""
+
+    @abstractmethod
+    def keys(self) -> list[str]:
+        """List all keys in the storage."""

@@ -5,12 +5,10 @@

 from typing import Any, cast

-import pandas as pd
 from datashaper import (
     AsyncType,
     Table,
     VerbCallbacks,
-    VerbInput,
     verb,
 )
 from datashaper.table_store.types import VerbResult, create_verb_result

@@ -27,10 +25,10 @@ from graphrag.index.storage import PipelineStorage
     treats_input_tables_as_immutable=True,
 )
 async def create_base_entity_graph(
-    input: VerbInput,
     callbacks: VerbCallbacks,
     cache: PipelineCache,
     storage: PipelineStorage,
+    runtime_storage: PipelineStorage,
     text_column: str,
     id_column: str,
     clustering_strategy: dict[str, Any],

@@ -48,10 +46,10 @@ async def create_base_entity_graph(
     **_kwargs: dict,
 ) -> VerbResult:
     """All the steps to create the base entity graph."""
-    source = cast(pd.DataFrame, input.get_input())
+    text_units = await runtime_storage.get("base_text_units")

     output = await create_base_entity_graph_flow(
-        source,
+        text_units,
         callbacks,
         cache,
         storage,

@@ -17,12 +17,14 @@ from datashaper.table_store.types import VerbResult, create_verb_result
 from graphrag.index.flows.create_base_text_units import (
     create_base_text_units as create_base_text_units_flow,
 )
+from graphrag.index.storage import PipelineStorage


 @verb(name="create_base_text_units", treats_input_tables_as_immutable=True)
-def create_base_text_units(
+async def create_base_text_units(
     input: VerbInput,
     callbacks: VerbCallbacks,
+    runtime_storage: PipelineStorage,
     chunk_column_name: str,
     n_tokens_column_name: str,
     chunk_by_columns: list[str],

@@ -41,9 +43,11 @@ def create_base_text_units(
         chunk_strategy=chunk_strategy,
     )

+    await runtime_storage.set("base_text_units", output)
+
     return create_verb_result(
         cast(
             Table,
-            output,
+            pd.DataFrame(),
         )
     )

@@ -5,12 +5,10 @@

 from typing import Any, cast

-import pandas as pd
 from datashaper import (
     AsyncType,
     Table,
     VerbCallbacks,
-    VerbInput,
     verb,
 )
 from datashaper.table_store.types import VerbResult, create_verb_result

@@ -19,13 +17,14 @@ from graphrag.index.cache import PipelineCache
 from graphrag.index.flows.create_final_covariates import (
     create_final_covariates as create_final_covariates_flow,
 )
+from graphrag.index.storage import PipelineStorage


 @verb(name="create_final_covariates", treats_input_tables_as_immutable=True)
 async def create_final_covariates(
-    input: VerbInput,
     callbacks: VerbCallbacks,
     cache: PipelineCache,
+    runtime_storage: PipelineStorage,
     column: str,
     covariate_type: str,
     extraction_strategy: dict[str, Any] | None,

@@ -35,10 +34,10 @@ async def create_final_covariates(
     **_kwargs: dict,
 ) -> VerbResult:
     """All the steps to extract and format covariates."""
-    source = cast(pd.DataFrame, input.get_input())
+    text_units = await runtime_storage.get("base_text_units")

     output = await create_final_covariates_flow(
-        source,
+        text_units,
         callbacks,
         cache,
         column,

@@ -19,6 +19,7 @@ from graphrag.index.cache import PipelineCache
 from graphrag.index.flows.create_final_text_units import (
     create_final_text_units as create_final_text_units_flow,
 )
+from graphrag.index.storage import PipelineStorage
 from graphrag.index.utils.ds_util import get_named_input_table, get_required_input_table


@@ -27,11 +28,12 @@ async def create_final_text_units(
     input: VerbInput,
     callbacks: VerbCallbacks,
     cache: PipelineCache,
+    runtime_storage: PipelineStorage,
     text_text_embed: dict | None = None,
     **_kwargs: dict,
 ) -> VerbResult:
     """All the steps to transform the text units."""
-    source = cast(pd.DataFrame, input.get_input())
+    text_units = await runtime_storage.get("base_text_units")
     final_entities = cast(
         pd.DataFrame, get_required_input_table(input, "entities").table
     )

@@ -44,7 +46,7 @@ async def create_final_text_units(
     final_covariates = cast(pd.DataFrame, final_covariates.table)

     output = await create_final_text_units_flow(
-        source,
+        text_units,
         final_entities,
         final_relationships,
         final_covariates,
tests/fixtures/min-csv/config.json (vendored, 2 changes)
@@ -16,7 +16,7 @@
             2500
         ],
         "subworkflows": 1,
-        "max_runtime": 100
+        "max_runtime": 300
     },
     "create_final_entities": {
         "row_range": [

@@ -178,19 +178,8 @@ class TestIndexer:
             workflows == expected_workflows
         ), f"Workflows missing from stats.json: {expected_workflows - workflows}. Unexpected workflows in stats.json: {workflows - expected_workflows}"

-        # [OPTIONAL] Check subworkflows
+        # [OPTIONAL] Check runtime
         for workflow in expected_workflows:
-            if "subworkflows" in workflow_config[workflow]:
-                # Check number of subworkflows
-                subworkflows = stats["workflows"][workflow]
-                expected_subworkflows = workflow_config[workflow].get(
-                    "subworkflows", None
-                )
-                if expected_subworkflows:
-                    assert (
-                        len(subworkflows) - 1 == expected_subworkflows
-                    ), f"Expected {expected_subworkflows} subworkflows, found: {len(subworkflows) - 1} for workflow: {workflow}: [{subworkflows}]"
-
             # Check max runtime
             max_runtime = workflow_config[workflow].get("max_runtime", None)
             if max_runtime:

@@ -200,8 +189,15 @@ class TestIndexer:

         # Check artifacts
         artifact_files = os.listdir(artifacts)
+        # check that the number of workflows matches the number of artifacts, but:
+        # (1) do not count workflows with only transient output
+        # (2) account for the stats.json file
+        transient_workflows = [
+            "workflow:create_base_text_units",
+        ]
         assert (
-            len(artifact_files) == len(expected_workflows) + 1
+            len(artifact_files)
+            == (len(expected_workflows) - len(transient_workflows) + 1)
         ), f"Expected {len(expected_workflows) + 1} artifacts, found: {len(artifact_files)}"

         for artifact in artifact_files:
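
To make the new artifact arithmetic concrete, a tiny worked example with hypothetical numbers (not taken from the fixture): twelve expected workflows, one of which is transient, plus stats.json.

# Hypothetical counts, for illustration only.
expected_workflows = {f"workflow:{i}" for i in range(12)}
transient_workflows = ["workflow:create_base_text_units"]

# 11 persisted workflow outputs + stats.json = 12 files on disk.
artifact_files = [f"table_{i}.parquet" for i in range(11)] + ["stats.json"]

assert len(artifact_files) == len(expected_workflows) - len(transient_workflows) + 1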
@@ -5,7 +5,7 @@ import networkx as nx
 import pytest

 from graphrag.config.enums import LLMType
-from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage
+from graphrag.index.run.utils import create_run_context
 from graphrag.index.workflows.v1.create_base_entity_graph import (
     build_steps,
     workflow_name,

@@ -55,7 +55,10 @@ async def test_create_base_entity_graph():
     ])
     expected = load_expected(workflow_name)

-    storage = MemoryPipelineStorage()
+    context = create_run_context(None, None, None)
+    await context.runtime_storage.set(
+        "base_text_units", input_tables["workflow:create_base_text_units"]
+    )

     config = get_config_for_workflow(workflow_name)
     config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_ENTITY_CONFIG

@@ -68,7 +71,7 @@ async def test_create_base_entity_graph():
         {
             "steps": steps,
         },
-        storage=storage,
+        context=context,
     )

     assert len(actual.columns) == len(

@@ -88,7 +91,7 @@ async def test_create_base_entity_graph():
     nodes = list(actual_graph_0.nodes(data=True))
     assert nodes[0][1]["description"] == "Company_A is a test company"

-    assert len(storage.keys()) == 0, "Storage should be empty"
+    assert len(context.storage.keys()) == 0, "Storage should be empty"


 async def test_create_base_entity_graph_with_embeddings():

@@ -97,6 +100,11 @@ async def test_create_base_entity_graph_with_embeddings():
     ])
     expected = load_expected(workflow_name)

+    context = create_run_context(None, None, None)
+    await context.runtime_storage.set(
+        "base_text_units", input_tables["workflow:create_base_text_units"]
+    )
+
     config = get_config_for_workflow(workflow_name)

     config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_ENTITY_CONFIG

@@ -110,6 +118,7 @@ async def test_create_base_entity_graph_with_embeddings():
         {
             "steps": steps,
         },
+        context=context,
     )

     assert (

@@ -123,7 +132,10 @@ async def test_create_base_entity_graph_with_snapshots():
         "workflow:create_base_text_units",
     ])

-    storage = MemoryPipelineStorage()
+    context = create_run_context(None, None, None)
+    await context.runtime_storage.set(
+        "base_text_units", input_tables["workflow:create_base_text_units"]
+    )

     config = get_config_for_workflow(workflow_name)


@@ -140,10 +152,10 @@ async def test_create_base_entity_graph_with_snapshots():
         {
             "steps": steps,
         },
-        storage=storage,
+        context=context,
     )

-    assert storage.keys() == [
+    assert context.storage.keys() == [
         "raw_extracted_entities.json",
         "merged_graph.graphml",
         "summarized_graph.graphml",

@@ -157,6 +169,11 @@ async def test_create_base_entity_graph_missing_llm_throws():
         "workflow:create_base_text_units",
     ])

+    context = create_run_context(None, None, None)
+    await context.runtime_storage.set(
+        "base_text_units", input_tables["workflow:create_base_text_units"]
+    )
+
     config = get_config_for_workflow(workflow_name)

     config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_ENTITY_CONFIG

@@ -170,4 +187,5 @@ async def test_create_base_entity_graph_missing_llm_throws():
         {
             "steps": steps,
         },
+        context=context,
     )
@@ -1,6 +1,7 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License

+from graphrag.index.run.utils import create_run_context
 from graphrag.index.workflows.v1.create_base_text_units import (
     build_steps,
     workflow_name,

@@ -19,17 +20,21 @@ async def test_create_base_text_units():
     input_tables = load_input_tables(inputs=[])
     expected = load_expected(workflow_name)

+    context = create_run_context(None, None, None)
+
     config = get_config_for_workflow(workflow_name)
     # test data was created with 4o, so we need to match the encoding for chunks to be identical
     config["text_chunk"]["strategy"]["encoding_name"] = "o200k_base"

     steps = build_steps(config)

-    actual = await get_workflow_output(
+    await get_workflow_output(
         input_tables,
         {
             "steps": steps,
         },
+        context,
     )

+    actual = await context.runtime_storage.get("base_text_units")
     compare_outputs(actual, expected)
@@ -6,6 +6,7 @@ from datashaper.errors import VerbParallelizationError
 from pandas.testing import assert_series_equal

 from graphrag.config.enums import LLMType
+from graphrag.index.run.utils import create_run_context
 from graphrag.index.workflows.v1.create_final_covariates import (
     build_steps,
     workflow_name,

@@ -31,6 +32,11 @@ async def test_create_final_covariates():
     input_tables = load_input_tables(["workflow:create_base_text_units"])
     expected = load_expected(workflow_name)

+    context = create_run_context(None, None, None)
+    await context.runtime_storage.set(
+        "base_text_units", input_tables["workflow:create_base_text_units"]
+    )
+
     config = get_config_for_workflow(workflow_name)

     config["claim_extract"]["strategy"]["llm"] = MOCK_LLM_CONFIG

@@ -42,6 +48,7 @@ async def test_create_final_covariates():
         {
             "steps": steps,
         },
+        context,
     )

     input = input_tables["workflow:create_base_text_units"]

@@ -81,6 +88,11 @@ async def test_create_final_covariates():
 async def test_create_final_covariates_missing_llm_throws():
     input_tables = load_input_tables(["workflow:create_base_text_units"])

+    context = create_run_context(None, None, None)
+    await context.runtime_storage.set(
+        "base_text_units", input_tables["workflow:create_base_text_units"]
+    )
+
     config = get_config_for_workflow(workflow_name)

     del config["claim_extract"]["strategy"]["llm"]

@@ -93,4 +105,5 @@ async def test_create_final_covariates_missing_llm_throws():
         {
             "steps": steps,
         },
+        context,
     )
@@ -1,7 +1,7 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License

-from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage
+from graphrag.index.run.utils import create_run_context
 from graphrag.index.workflows.v1.create_final_nodes import (
     build_steps,
     workflow_name,

@@ -22,7 +22,7 @@ async def test_create_final_nodes():
     ])
     expected = load_expected(workflow_name)

-    storage = MemoryPipelineStorage()
+    context = create_run_context(None, None, None)

     config = get_config_for_workflow(workflow_name)


@@ -37,12 +37,12 @@ async def test_create_final_nodes():
         {
             "steps": steps,
         },
-        storage=storage,
+        context=context,
     )

     compare_outputs(actual, expected)

-    assert len(storage.keys()) == 0, "Storage should be empty"
+    assert len(context.storage.keys()) == 0, "Storage should be empty"


 async def test_create_final_nodes_with_snapshot():

@@ -51,7 +51,7 @@ async def test_create_final_nodes_with_snapshot():
     ])
     expected = load_expected(workflow_name)

-    storage = MemoryPipelineStorage()
+    context = create_run_context(None, None, None)

     config = get_config_for_workflow(workflow_name)


@@ -67,11 +67,11 @@ async def test_create_final_nodes_with_snapshot():
         {
             "steps": steps,
         },
-        storage=storage,
+        context=context,
     )

     assert actual.shape == expected.shape, "Graph dataframe shapes differ"

-    assert storage.keys() == [
+    assert context.storage.keys() == [
         "top_level_nodes.json",
     ], "Graph snapshot keys differ"
@@ -1,6 +1,7 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License

+from graphrag.index.run.utils import create_run_context
 from graphrag.index.workflows.v1.create_final_text_units import (
     build_steps,
     workflow_name,

@@ -24,6 +25,11 @@ async def test_create_final_text_units():
     ])
     expected = load_expected(workflow_name)

+    context = create_run_context(None, None, None)
+    await context.runtime_storage.set(
+        "base_text_units", input_tables["workflow:create_base_text_units"]
+    )
+
     config = get_config_for_workflow(workflow_name)

     config["covariates_enabled"] = True

@@ -36,6 +42,7 @@ async def test_create_final_text_units():
         {
             "steps": steps,
         },
+        context=context,
     )

     compare_outputs(actual, expected)

@@ -50,6 +57,11 @@ async def test_create_final_text_units_no_covariates():
     ])
     expected = load_expected(workflow_name)

+    context = create_run_context(None, None, None)
+    await context.runtime_storage.set(
+        "base_text_units", input_tables["workflow:create_base_text_units"]
+    )
+
     config = get_config_for_workflow(workflow_name)

     config["covariates_enabled"] = False

@@ -62,6 +74,7 @@ async def test_create_final_text_units_no_covariates():
         {
             "steps": steps,
         },
+        context=context,
     )

     # we're short a covariate_ids column

@@ -81,6 +94,11 @@ async def test_create_final_text_units_with_embeddings():
     ])
     expected = load_expected(workflow_name)

+    context = create_run_context(None, None, None)
+    await context.runtime_storage.set(
+        "base_text_units", input_tables["workflow:create_base_text_units"]
+    )
+
     config = get_config_for_workflow(workflow_name)

     config["covariates_enabled"] = True

@@ -96,6 +114,7 @@ async def test_create_final_text_units_with_embeddings():
         {
             "steps": steps,
         },
+        context=context,
     )

     assert "text_embedding" in actual.columns
@@ -12,8 +12,8 @@ from graphrag.index import (
     PipelineWorkflowConfig,
     create_pipeline_config,
 )
-from graphrag.index.run.utils import _create_run_context
-from graphrag.index.storage.pipeline_storage import PipelineStorage
+from graphrag.index.context import PipelineRunContext
+from graphrag.index.run.utils import create_run_context

 pd.set_option("display.max_columns", None)


@@ -60,7 +60,7 @@ def get_config_for_workflow(name: str) -> PipelineWorkflowConfig:
 async def get_workflow_output(
     input_tables: dict[str, pd.DataFrame],
     schema: dict,
-    storage: PipelineStorage | None = None,
+    context: PipelineRunContext | None = None,
 ) -> pd.DataFrame:
     """Pass in the input tables, the schema, and the output name"""


@@ -70,9 +70,9 @@ async def get_workflow_output(
         input_tables=input_tables,
     )

-    context = _create_run_context(storage, None, None)
+    run_context = context or create_run_context(None, None, None)

-    await workflow.run(context=context)
+    await workflow.run(context=run_context)

     # if there's only one output, it is the default here, no name required
     return cast(pd.DataFrame, workflow.output())
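
Taken together, the test changes above all follow one setup pattern: build a run context, seed the runtime-only table that downstream verbs expect, then hand the context to the workflow under test. A condensed, self-contained sketch of that setup using only the calls confirmed by this diff (the DataFrame content is made up; the real tests load fixture tables):

import asyncio

import pandas as pd

from graphrag.index.run.utils import create_run_context


async def main() -> None:
    context = create_run_context(None, None, None)
    # Seed the runtime-only table that downstream verbs expect to find.
    await context.runtime_storage.set(
        "base_text_units", pd.DataFrame({"id": [1], "chunk": ["some text"]})
    )
    assert await context.runtime_storage.get("base_text_units") is not None


asyncio.run(main())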