Mirror of https://github.com/microsoft/graphrag.git, synced 2025-11-04 03:39:55 +00:00
	Rework workflow architecture (#1311)
* Rename pipeline_storage file
* Add runtime storage option to context
* Fix import
* Switch to memory storage for runtime
* Infra for workflow runtime storage
* Migrate base_text_units to runtime storage
* Fix comment
* Semver
* Remove whitespace
* Remove subflow smoke tests and ignore transient artifacts
* Remove entity graph from transient list (not yet implemented)
* Increase smoke runtime allotment for create_base_entity_graph
* Revert format fix
* Remove noqa
This commit is contained in:
parent ac09e0a740
commit 94f1e62e5c
@@ -0,0 +1,4 @@
{
  "type": "patch",
  "description": "Add runtime-only storage option."
}
@@ -34,7 +34,11 @@ class PipelineRunContext:

    stats: PipelineRunStats
    storage: PipelineStorage
    "Long-term storage for pipeline verbs to use. Items written here will be written to the storage provider."
    cache: PipelineCache
    "Cache instance for reading previous LLM responses."
    runtime_storage: PipelineStorage
    "Runtime only storage for pipeline verbs to use. Items written here will only live in memory during the current run."


# TODO: For now, just has the same props available to it
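The hunk above is the heart of the rework: PipelineRunContext now carries two storage tiers. `storage` remains the long-term provider that artifacts are written to, while the new `runtime_storage` holds tables that only live in memory for the current run. The sketch below is illustrative only (the `produce` and `consume` helpers are hypothetical); it shows the producer/consumer hand-off that the migrated `base_text_units` workflow uses, assuming the async `set`/`get` interface visible in the hunks that follow.

import asyncio

import pandas as pd

from graphrag.index.run.utils import create_run_context


async def produce(context) -> None:
    # Hypothetical producer verb: write a transient table to runtime-only storage.
    # Nothing here reaches the long-term storage provider, so no artifact is emitted.
    chunks = pd.DataFrame({"id": ["1", "2"], "chunk": ["first chunk", "second chunk"]})
    await context.runtime_storage.set("base_text_units", chunks)


async def consume(context) -> pd.DataFrame:
    # Hypothetical consumer verb: a later workflow reads the same table back by key.
    return await context.runtime_storage.get("base_text_units")


async def main() -> None:
    # create_run_context defaults every component, including runtime_storage,
    # to an in-memory implementation (see the run/utils hunks below).
    context = create_run_context(storage=None, cache=None, stats=None)
    await produce(context)
    print(await consume(context))


if __name__ == "__main__":
    asyncio.run(main())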
@@ -69,7 +69,3 @@ async def _write_workflow_stats(
        await _save_profiler_stats(
            storage, workflow.name, workflow_result.memory_profile
        )

    log.debug(
        "first row of %s => %s", workflow.name, workflow.output().iloc[0].to_json()
    )
@@ -32,8 +32,8 @@ from graphrag.index.run.utils import (
    _apply_substitutions,
    _create_input,
    _create_reporter,
    _create_run_context,
    _validate_dataset,
    create_run_context,
)
from graphrag.index.run.workflow import (
    _create_callback_chain,
@@ -200,7 +200,7 @@ async def run_pipeline(
    """
    start_time = time.time()

    context = _create_run_context(storage=storage, cache=cache, stats=None)
    context = create_run_context(storage=storage, cache=cache, stats=None)

    progress_reporter = progress_reporter or NullProgressReporter()
    callbacks = callbacks or ConsoleWorkflowCallbacks()
@@ -103,7 +103,7 @@ def _apply_substitutions(config: PipelineConfig, run_id: str) -> PipelineConfig:
    return config


def _create_run_context(
def create_run_context(
    storage: PipelineStorage | None,
    cache: PipelineCache | None,
    stats: PipelineRunStats | None,
@@ -113,4 +113,5 @@ def _create_run_context(
        stats=stats or PipelineRunStats(),
        cache=cache or InMemoryCache(),
        storage=storage or MemoryPipelineStorage(),
        runtime_storage=MemoryPipelineStorage(),
    )
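With the rename, `create_run_context` becomes part of the module's public surface and always attaches a fresh `MemoryPipelineStorage` as `runtime_storage`, regardless of which long-term storage the caller supplies. A short, hedged usage sketch:

from graphrag.index.run.utils import create_run_context

# None for storage/cache/stats falls back to in-memory defaults, which is also
# what the workflow unit tests further down rely on.
context = create_run_context(storage=None, cache=None, stats=None)

# runtime_storage is a separate MemoryPipelineStorage instance from context.storage,
# so transient tables never mix with persisted artifacts.
assert context.runtime_storage is not context.storage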
@@ -39,7 +39,18 @@ async def _inject_workflow_data_dependencies(
    log.info("dependencies for %s: %s", workflow.name, deps)
    for id in deps:
        workflow_id = f"workflow:{id}"
        table = await _load_table_from_storage(f"{id}.parquet", storage)
        try:
            table = await _load_table_from_storage(f"{id}.parquet", storage)
        except ValueError:
            # our workflows now allow transient tables, and we avoid putting those in primary storage
            # however, we need to keep the table in the dependency list for proper execution order
            # this allows us to catch missing table errors but emit a warning for pipeline users who may genuinely have an error (which we expect to be very rare)
            # todo: this issue will resolve itself if we remove DataShaper completely
            log.warning(
                "Dependency table %s not found in storage: it may be a runtime-only in-memory table. If you see further errors, this may be an actual problem.",
                id,
            )
            table = pd.DataFrame()
        workflow.add_table(workflow_id, table)
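Dependency injection is now tolerant of transient tables: if a dependency's parquet file is not in primary storage, the loader logs a warning and substitutes an empty DataFrame so the execution order is preserved. A hedged, standalone sketch of that fallback; `load_dependency` is a hypothetical stand-in for the private loader, not graphrag's implementation:

import io
import logging

import pandas as pd

log = logging.getLogger(__name__)


async def load_dependency(name: str, storage) -> pd.DataFrame:
    # Illustrative fallback loader, not graphrag's private _load_table_from_storage.
    try:
        raw = await storage.get(f"{name}.parquet", as_bytes=True)
        if raw is None:
            # Mirror the "missing table" failure mode as a ValueError.
            msg = f"Could not find {name}.parquet in storage"
            raise ValueError(msg)
        return pd.read_parquet(io.BytesIO(raw))
    except ValueError:
        # Transient, runtime-only tables intentionally never land in primary
        # storage; keep the dependency slot filled so execution order is preserved,
        # but warn in case the table is genuinely missing.
        log.warning("Dependency table %s not found in storage; it may be runtime-only.", name)
        return pd.DataFrame()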
@@ -276,6 +276,11 @@ class BlobPipelineStorage(PipelineStorage):
            self._storage_account_blob_url,
        )

    def keys(self) -> list[str]:
        """Return the keys in the storage."""
        msg = "Blob storage does yet not support listing keys."
        raise NotImplementedError(msg)

    def _keyname(self, key: str) -> str:
        """Get the key name."""
        return str(Path(self._path_prefix) / key)
@@ -143,6 +143,10 @@ class FilePipelineStorage(PipelineStorage):
            return self
        return FilePipelineStorage(str(Path(self._root_dir) / Path(name)))

    def keys(self) -> list[str]:
        """Return the keys in the storage."""
        return os.listdir(self._root_dir)


def join_path(file_path: str, file_name: str) -> Path:
    """Join a path and a file. Independent of the OS."""
@@ -32,7 +32,7 @@ class MemoryPipelineStorage(FilePipelineStorage):
        -------
            - output - The value for the given key.
        """
        return self._storage.get(key) or await super().get(key, as_bytes, encoding)
        return self._storage.get(key)

    async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
        """Set the value for the given key.
@@ -53,7 +53,7 @@ class MemoryPipelineStorage(FilePipelineStorage):
        -------
            - output - True if the key exists in the storage, False otherwise.
        """
        return key in self._storage or await super().has(key)
        return key in self._storage

    async def delete(self, key: str) -> None:
        """Delete the given key from the storage.
@@ -69,7 +69,7 @@ class MemoryPipelineStorage(FilePipelineStorage):

    def child(self, name: str | None) -> "PipelineStorage":
        """Create a child storage instance."""
        return self
        return MemoryPipelineStorage()

    def keys(self) -> list[str]:
        """Return the keys in the storage."""
@@ -76,3 +76,7 @@ class PipelineStorage(metaclass=ABCMeta):
    @abstractmethod
    def child(self, name: str | None) -> "PipelineStorage":
        """Create a child storage instance."""

    @abstractmethod
    def keys(self) -> list[str]:
        """List all keys in the storage."""
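The storage contract now includes keys() alongside child(), and MemoryPipelineStorage answers get/has purely from its own dict while handing out a fresh instance per child. A minimal, hedged sketch of that shape in one place; DictStorage is a hypothetical standalone class, not graphrag's implementation:

from typing import Any


class DictStorage:
    """Illustrative in-memory storage following the same contract as the diff above."""

    def __init__(self) -> None:
        self._storage: dict[str, Any] = {}

    async def get(self, key: str, as_bytes: bool | None = None, encoding: str | None = None) -> Any:
        # Answer purely from memory; no fallback to a parent or file-backed storage.
        return self._storage.get(key)

    async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
        self._storage[key] = value

    async def has(self, key: str) -> bool:
        return key in self._storage

    async def delete(self, key: str) -> None:
        del self._storage[key]

    def keys(self) -> list[str]:
        # New in this PR: storages can enumerate their contents, which the tests
        # use to assert that nothing leaked into long-term storage.
        return list(self._storage.keys())

    def child(self, name: str | None) -> "DictStorage":
        # A child is an independent namespace rather than an alias of the parent.
        return DictStorage()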
@@ -5,12 +5,10 @@

from typing import Any, cast

import pandas as pd
from datashaper import (
    AsyncType,
    Table,
    VerbCallbacks,
    VerbInput,
    verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result
@@ -27,10 +25,10 @@ from graphrag.index.storage import PipelineStorage
    treats_input_tables_as_immutable=True,
)
async def create_base_entity_graph(
    input: VerbInput,
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    storage: PipelineStorage,
    runtime_storage: PipelineStorage,
    text_column: str,
    id_column: str,
    clustering_strategy: dict[str, Any],
@@ -48,10 +46,10 @@ async def create_base_entity_graph(
    **_kwargs: dict,
) -> VerbResult:
    """All the steps to create the base entity graph."""
    source = cast(pd.DataFrame, input.get_input())
    text_units = await runtime_storage.get("base_text_units")

    output = await create_base_entity_graph_flow(
        source,
        text_units,
        callbacks,
        cache,
        storage,
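On the consumer side, verbs that used to read the chunked text from a DataShaper input table now take a runtime_storage argument and fetch "base_text_units" from it themselves; the cache and snapshot storage arguments are unchanged. A hedged, simplified sketch of that shape (the real flow call is replaced by a trivial stand-in, and how the argument gets injected is assumed, not shown here):

import pandas as pd


async def entity_graph_like_verb(runtime_storage) -> pd.DataFrame:
    # Fetch the transient chunk table written earlier in the run; the key matches
    # the one used by create_base_text_units in the hunks below.
    text_units = await runtime_storage.get("base_text_units")

    # Stand-in for create_base_entity_graph_flow: derive something trivial from the
    # chunks so the shape of the pattern is visible without the real extraction logic.
    return pd.DataFrame({"n_text_units": [len(text_units)]})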
@@ -17,12 +17,14 @@ from datashaper.table_store.types import VerbResult, create_verb_result
from graphrag.index.flows.create_base_text_units import (
    create_base_text_units as create_base_text_units_flow,
)
from graphrag.index.storage import PipelineStorage


@verb(name="create_base_text_units", treats_input_tables_as_immutable=True)
def create_base_text_units(
async def create_base_text_units(
    input: VerbInput,
    callbacks: VerbCallbacks,
    runtime_storage: PipelineStorage,
    chunk_column_name: str,
    n_tokens_column_name: str,
    chunk_by_columns: list[str],
@@ -41,9 +43,11 @@ def create_base_text_units(
        chunk_strategy=chunk_strategy,
    )

    await runtime_storage.set("base_text_units", output)

    return create_verb_result(
        cast(
            Table,
            output,
            pd.DataFrame(),
        )
    )
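The producer verb is the mirror image: it still returns a VerbResult so DataShaper's scheduling is unchanged, but the real table goes to runtime_storage and the table handed back is empty, which is why this workflow no longer emits a parquet artifact. A hedged sketch of the pattern with the DataShaper plumbing stripped away; the chunking step is faked:

import pandas as pd


async def text_units_like_verb(documents: pd.DataFrame, runtime_storage) -> pd.DataFrame:
    # Stand-in for create_base_text_units_flow: pretend each document is one chunk.
    chunks = documents.rename(columns={"text": "chunk"})

    # Stash the real output for downstream verbs in runtime-only storage...
    await runtime_storage.set("base_text_units", chunks)

    # ...and hand back an empty table, so nothing is written out as an artifact
    # for this transient workflow.
    return pd.DataFrame()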
@@ -5,12 +5,10 @@

from typing import Any, cast

import pandas as pd
from datashaper import (
    AsyncType,
    Table,
    VerbCallbacks,
    VerbInput,
    verb,
)
from datashaper.table_store.types import VerbResult, create_verb_result
@@ -19,13 +17,14 @@ from graphrag.index.cache import PipelineCache
from graphrag.index.flows.create_final_covariates import (
    create_final_covariates as create_final_covariates_flow,
)
from graphrag.index.storage import PipelineStorage


@verb(name="create_final_covariates", treats_input_tables_as_immutable=True)
async def create_final_covariates(
    input: VerbInput,
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    runtime_storage: PipelineStorage,
    column: str,
    covariate_type: str,
    extraction_strategy: dict[str, Any] | None,
@@ -35,10 +34,10 @@ async def create_final_covariates(
    **_kwargs: dict,
) -> VerbResult:
    """All the steps to extract and format covariates."""
    source = cast(pd.DataFrame, input.get_input())
    text_units = await runtime_storage.get("base_text_units")

    output = await create_final_covariates_flow(
        source,
        text_units,
        callbacks,
        cache,
        column,
@@ -19,6 +19,7 @@ from graphrag.index.cache import PipelineCache
from graphrag.index.flows.create_final_text_units import (
    create_final_text_units as create_final_text_units_flow,
)
from graphrag.index.storage import PipelineStorage
from graphrag.index.utils.ds_util import get_named_input_table, get_required_input_table
@@ -27,11 +28,12 @@ async def create_final_text_units(
    input: VerbInput,
    callbacks: VerbCallbacks,
    cache: PipelineCache,
    runtime_storage: PipelineStorage,
    text_text_embed: dict | None = None,
    **_kwargs: dict,
) -> VerbResult:
    """All the steps to transform the text units."""
    source = cast(pd.DataFrame, input.get_input())
    text_units = await runtime_storage.get("base_text_units")
    final_entities = cast(
        pd.DataFrame, get_required_input_table(input, "entities").table
    )
@@ -44,7 +46,7 @@ async def create_final_text_units(
        final_covariates = cast(pd.DataFrame, final_covariates.table)

    output = await create_final_text_units_flow(
        source,
        text_units,
        final_entities,
        final_relationships,
        final_covariates,
tests/fixtures/min-csv/config.json (vendored)
@@ -16,7 +16,7 @@
                2500
            ],
            "subworkflows": 1,
            "max_runtime": 100
            "max_runtime": 300
        },
        "create_final_entities": {
            "row_range": [
@@ -178,30 +178,26 @@ class TestIndexer:
            workflows == expected_workflows
        ), f"Workflows missing from stats.json: {expected_workflows - workflows}. Unexpected workflows in stats.json: {workflows - expected_workflows}"

        # [OPTIONAL] Check subworkflows
        # [OPTIONAL] Check runtime
        for workflow in expected_workflows:
            if "subworkflows" in workflow_config[workflow]:
                # Check number of subworkflows
                subworkflows = stats["workflows"][workflow]
                expected_subworkflows = workflow_config[workflow].get(
                    "subworkflows", None
                )
                if expected_subworkflows:
                    assert (
                        len(subworkflows) - 1 == expected_subworkflows
                    ), f"Expected {expected_subworkflows} subworkflows, found: {len(subworkflows) - 1} for workflow: {workflow}: [{subworkflows}]"

                # Check max runtime
                max_runtime = workflow_config[workflow].get("max_runtime", None)
                if max_runtime:
                    assert (
                        stats["workflows"][workflow]["overall"] <= max_runtime
                    ), f"Expected max runtime of {max_runtime}, found: {stats['workflows'][workflow]['overall']} for workflow: {workflow}"
            # Check max runtime
            max_runtime = workflow_config[workflow].get("max_runtime", None)
            if max_runtime:
                assert (
                    stats["workflows"][workflow]["overall"] <= max_runtime
                ), f"Expected max runtime of {max_runtime}, found: {stats['workflows'][workflow]['overall']} for workflow: {workflow}"

        # Check artifacts
        artifact_files = os.listdir(artifacts)
        # check that the number of workflows matches the number of artifacts, but:
        # (1) do not count workflows with only transient output
        # (2) account for the stats.json file
        transient_workflows = [
            "workflow:create_base_text_units",
        ]
        assert (
            len(artifact_files) == len(expected_workflows) + 1
            len(artifact_files)
            == (len(expected_workflows) - len(transient_workflows) + 1)
        ), f"Expected {len(expected_workflows) + 1} artifacts, found: {len(artifact_files)}"

        for artifact in artifact_files:
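The smoke test drops the subworkflow-count assertions (only the per-workflow runtime ceiling remains) and, when counting artifacts, skips workflows whose output is transient. The arithmetic the assertion encodes looks like this hedged example with made-up numbers:

# Made-up workflow names, purely to show the counting rule the assertion uses.
expected_workflows = {
    "create_base_text_units",
    "create_base_entity_graph",
    "create_final_entities",
}
transient_workflows = ["workflow:create_base_text_units"]  # emits no parquet artifact

# One parquet per non-transient workflow, plus the stats.json file.
expected_artifact_count = len(expected_workflows) - len(transient_workflows) + 1
assert expected_artifact_count == 3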
@@ -5,7 +5,7 @@ import networkx as nx
import pytest

from graphrag.config.enums import LLMType
from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage
from graphrag.index.run.utils import create_run_context
from graphrag.index.workflows.v1.create_base_entity_graph import (
    build_steps,
    workflow_name,
@@ -55,7 +55,10 @@ async def test_create_base_entity_graph():
    ])
    expected = load_expected(workflow_name)

    storage = MemoryPipelineStorage()
    context = create_run_context(None, None, None)
    await context.runtime_storage.set(
        "base_text_units", input_tables["workflow:create_base_text_units"]
    )

    config = get_config_for_workflow(workflow_name)
    config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_ENTITY_CONFIG
@@ -68,7 +71,7 @@ async def test_create_base_entity_graph():
        {
            "steps": steps,
        },
        storage=storage,
        context=context,
    )

    assert len(actual.columns) == len(
@@ -88,7 +91,7 @@ async def test_create_base_entity_graph():
    nodes = list(actual_graph_0.nodes(data=True))
    assert nodes[0][1]["description"] == "Company_A is a test company"

    assert len(storage.keys()) == 0, "Storage should be empty"
    assert len(context.storage.keys()) == 0, "Storage should be empty"


async def test_create_base_entity_graph_with_embeddings():
@@ -97,6 +100,11 @@ async def test_create_base_entity_graph_with_embeddings():
    ])
    expected = load_expected(workflow_name)

    context = create_run_context(None, None, None)
    await context.runtime_storage.set(
        "base_text_units", input_tables["workflow:create_base_text_units"]
    )

    config = get_config_for_workflow(workflow_name)

    config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_ENTITY_CONFIG
@@ -110,6 +118,7 @@ async def test_create_base_entity_graph_with_embeddings():
        {
            "steps": steps,
        },
        context=context,
    )

    assert (
@@ -123,7 +132,10 @@ async def test_create_base_entity_graph_with_snapshots():
        "workflow:create_base_text_units",
    ])

    storage = MemoryPipelineStorage()
    context = create_run_context(None, None, None)
    await context.runtime_storage.set(
        "base_text_units", input_tables["workflow:create_base_text_units"]
    )

    config = get_config_for_workflow(workflow_name)

@@ -140,10 +152,10 @@ async def test_create_base_entity_graph_with_snapshots():
        {
            "steps": steps,
        },
        storage=storage,
        context=context,
    )

    assert storage.keys() == [
    assert context.storage.keys() == [
        "raw_extracted_entities.json",
        "merged_graph.graphml",
        "summarized_graph.graphml",
@@ -157,6 +169,11 @@ async def test_create_base_entity_graph_missing_llm_throws():
        "workflow:create_base_text_units",
    ])

    context = create_run_context(None, None, None)
    await context.runtime_storage.set(
        "base_text_units", input_tables["workflow:create_base_text_units"]
    )

    config = get_config_for_workflow(workflow_name)

    config["entity_extract"]["strategy"]["llm"] = MOCK_LLM_ENTITY_CONFIG
@@ -170,4 +187,5 @@ async def test_create_base_entity_graph_missing_llm_throws():
            {
                "steps": steps,
            },
            context=context,
        )
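Every workflow unit test now follows the same setup: build an in-memory run context, seed the transient "base_text_units" table into runtime_storage, and hand the whole context to the runner instead of a bare MemoryPipelineStorage. A hedged distillation of that pattern; load_input_tables, build_steps, and get_workflow_output are the existing test fixtures and are assumed to be in scope:

from graphrag.index.run.utils import create_run_context


async def run_with_seeded_context(input_tables, steps):
    # Fresh context; storage, cache, stats, and runtime_storage are all in-memory.
    context = create_run_context(None, None, None)

    # Seed the dependency the workflow would normally get from an earlier
    # create_base_text_units run in the same pipeline.
    await context.runtime_storage.set(
        "base_text_units", input_tables["workflow:create_base_text_units"]
    )

    # get_workflow_output now takes the context rather than a storage object.
    actual = await get_workflow_output(input_tables, {"steps": steps}, context=context)

    # Unless a workflow explicitly snapshots, long-term storage stays empty.
    assert len(context.storage.keys()) == 0, "Storage should be empty"
    return actual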
@@ -1,6 +1,7 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

from graphrag.index.run.utils import create_run_context
from graphrag.index.workflows.v1.create_base_text_units import (
    build_steps,
    workflow_name,
@@ -19,17 +20,21 @@ async def test_create_base_text_units():
    input_tables = load_input_tables(inputs=[])
    expected = load_expected(workflow_name)

    context = create_run_context(None, None, None)

    config = get_config_for_workflow(workflow_name)
    # test data was created with 4o, so we need to match the encoding for chunks to be identical
    config["text_chunk"]["strategy"]["encoding_name"] = "o200k_base"

    steps = build_steps(config)

    actual = await get_workflow_output(
    await get_workflow_output(
        input_tables,
        {
            "steps": steps,
        },
        context,
    )

    actual = await context.runtime_storage.get("base_text_units")
    compare_outputs(actual, expected)
@@ -6,6 +6,7 @@ from datashaper.errors import VerbParallelizationError
from pandas.testing import assert_series_equal

from graphrag.config.enums import LLMType
from graphrag.index.run.utils import create_run_context
from graphrag.index.workflows.v1.create_final_covariates import (
    build_steps,
    workflow_name,
@@ -31,6 +32,11 @@ async def test_create_final_covariates():
    input_tables = load_input_tables(["workflow:create_base_text_units"])
    expected = load_expected(workflow_name)

    context = create_run_context(None, None, None)
    await context.runtime_storage.set(
        "base_text_units", input_tables["workflow:create_base_text_units"]
    )

    config = get_config_for_workflow(workflow_name)

    config["claim_extract"]["strategy"]["llm"] = MOCK_LLM_CONFIG
@@ -42,6 +48,7 @@ async def test_create_final_covariates():
        {
            "steps": steps,
        },
        context,
    )

    input = input_tables["workflow:create_base_text_units"]
@@ -81,6 +88,11 @@ async def test_create_final_covariates():
async def test_create_final_covariates_missing_llm_throws():
    input_tables = load_input_tables(["workflow:create_base_text_units"])

    context = create_run_context(None, None, None)
    await context.runtime_storage.set(
        "base_text_units", input_tables["workflow:create_base_text_units"]
    )

    config = get_config_for_workflow(workflow_name)

    del config["claim_extract"]["strategy"]["llm"]
@@ -93,4 +105,5 @@ async def test_create_final_covariates_missing_llm_throws():
            {
                "steps": steps,
            },
            context,
        )
@@ -1,7 +1,7 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage
from graphrag.index.run.utils import create_run_context
from graphrag.index.workflows.v1.create_final_nodes import (
    build_steps,
    workflow_name,
@@ -22,7 +22,7 @@ async def test_create_final_nodes():
    ])
    expected = load_expected(workflow_name)

    storage = MemoryPipelineStorage()
    context = create_run_context(None, None, None)

    config = get_config_for_workflow(workflow_name)

@@ -37,12 +37,12 @@ async def test_create_final_nodes():
        {
            "steps": steps,
        },
        storage=storage,
        context=context,
    )

    compare_outputs(actual, expected)

    assert len(storage.keys()) == 0, "Storage should be empty"
    assert len(context.storage.keys()) == 0, "Storage should be empty"


async def test_create_final_nodes_with_snapshot():
@@ -51,7 +51,7 @@ async def test_create_final_nodes_with_snapshot():
    ])
    expected = load_expected(workflow_name)

    storage = MemoryPipelineStorage()
    context = create_run_context(None, None, None)

    config = get_config_for_workflow(workflow_name)

@@ -67,11 +67,11 @@ async def test_create_final_nodes_with_snapshot():
        {
            "steps": steps,
        },
        storage=storage,
        context=context,
    )

    assert actual.shape == expected.shape, "Graph dataframe shapes differ"

    assert storage.keys() == [
    assert context.storage.keys() == [
        "top_level_nodes.json",
    ], "Graph snapshot keys differ"
@@ -1,6 +1,7 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

from graphrag.index.run.utils import create_run_context
from graphrag.index.workflows.v1.create_final_text_units import (
    build_steps,
    workflow_name,
@@ -24,6 +25,11 @@ async def test_create_final_text_units():
    ])
    expected = load_expected(workflow_name)

    context = create_run_context(None, None, None)
    await context.runtime_storage.set(
        "base_text_units", input_tables["workflow:create_base_text_units"]
    )

    config = get_config_for_workflow(workflow_name)

    config["covariates_enabled"] = True
@@ -36,6 +42,7 @@ async def test_create_final_text_units():
        {
            "steps": steps,
        },
        context=context,
    )

    compare_outputs(actual, expected)
@@ -50,6 +57,11 @@ async def test_create_final_text_units_no_covariates():
    ])
    expected = load_expected(workflow_name)

    context = create_run_context(None, None, None)
    await context.runtime_storage.set(
        "base_text_units", input_tables["workflow:create_base_text_units"]
    )

    config = get_config_for_workflow(workflow_name)

    config["covariates_enabled"] = False
@@ -62,6 +74,7 @@ async def test_create_final_text_units_no_covariates():
        {
            "steps": steps,
        },
        context=context,
    )

    # we're short a covariate_ids column
@@ -81,6 +94,11 @@ async def test_create_final_text_units_with_embeddings():
    ])
    expected = load_expected(workflow_name)

    context = create_run_context(None, None, None)
    await context.runtime_storage.set(
        "base_text_units", input_tables["workflow:create_base_text_units"]
    )

    config = get_config_for_workflow(workflow_name)

    config["covariates_enabled"] = True
@@ -96,6 +114,7 @@ async def test_create_final_text_units_with_embeddings():
        {
            "steps": steps,
        },
        context=context,
    )

    assert "text_embedding" in actual.columns
@@ -12,8 +12,8 @@ from graphrag.index import (
    PipelineWorkflowConfig,
    create_pipeline_config,
)
from graphrag.index.run.utils import _create_run_context
from graphrag.index.storage.pipeline_storage import PipelineStorage
from graphrag.index.context import PipelineRunContext
from graphrag.index.run.utils import create_run_context

pd.set_option("display.max_columns", None)

@@ -60,7 +60,7 @@ def get_config_for_workflow(name: str) -> PipelineWorkflowConfig:
async def get_workflow_output(
    input_tables: dict[str, pd.DataFrame],
    schema: dict,
    storage: PipelineStorage | None = None,
    context: PipelineRunContext | None = None,
) -> pd.DataFrame:
    """Pass in the input tables, the schema, and the output name"""

@@ -70,9 +70,9 @@ async def get_workflow_output(
        input_tables=input_tables,
    )

    context = _create_run_context(storage, None, None)
    run_context = context or create_run_context(None, None, None)

    await workflow.run(context=context)
    await workflow.run(context=run_context)

    # if there's only one output, it is the default here, no name required
    return cast(pd.DataFrame, workflow.output())
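For callers of the shared test helper, the change is simply which object they pass: an optional PipelineRunContext (or nothing, to get a throwaway in-memory one) instead of a storage instance. A short, hedged usage sketch, with placeholder tables standing in for the fixture data and the helper passed in explicitly:

import pandas as pd

from graphrag.index.run.utils import create_run_context


async def two_ways_to_call(get_workflow_output, steps):
    # Placeholder fixture data; a real test loads these from the parquet fixtures.
    input_tables = {"source": pd.DataFrame({"text": ["a document"]})}

    # 1. No context: the helper builds its own in-memory run context.
    await get_workflow_output(input_tables, {"steps": steps})

    # 2. Explicit context: needed whenever the test must pre-seed runtime_storage
    #    or inspect context.storage afterwards.
    context = create_run_context(None, None, None)
    return await get_workflow_output(input_tables, {"steps": steps}, context=context)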