diff --git a/.semversioner/next-release/minor-20250519234123676262.json b/.semversioner/next-release/minor-20250519234123676262.json new file mode 100644 index 00000000..4d3e2af9 --- /dev/null +++ b/.semversioner/next-release/minor-20250519234123676262.json @@ -0,0 +1,4 @@ +{ + "type": "minor", + "description": "Allow injection of custom pipelines." +} diff --git a/graphrag/api/index.py b/graphrag/api/index.py index f530bfa4..b4265542 100644 --- a/graphrag/api/index.py +++ b/graphrag/api/index.py @@ -27,7 +27,7 @@ log = logging.getLogger(__name__) async def build_index( config: GraphRagConfig, - method: IndexingMethod = IndexingMethod.Standard, + method: IndexingMethod | str = IndexingMethod.Standard, is_update_run: bool = False, memory_profile: bool = False, callbacks: list[WorkflowCallbacks] | None = None, @@ -65,7 +65,9 @@ async def build_index( if memory_profile: log.warning("New pipeline does not yet support memory profiling.") - pipeline = PipelineFactory.create_pipeline(config, method, is_update_run) + # todo: this could propagate out to the cli for better clarity, but will be a breaking api change + method = _get_method(method, is_update_run) + pipeline = PipelineFactory.create_pipeline(config, method) workflow_callbacks.pipeline_start(pipeline.names()) @@ -90,3 +92,8 @@ async def build_index( def register_workflow_function(name: str, workflow: WorkflowFunction): """Register a custom workflow function. You can then include the name in the settings.yaml workflows list.""" PipelineFactory.register(name, workflow) + + +def _get_method(method: IndexingMethod | str, is_update_run: bool) -> str: + m = method.value if isinstance(method, IndexingMethod) else method + return f"{m}-update" if is_update_run else m diff --git a/graphrag/api/prompt_tune.py b/graphrag/api/prompt_tune.py index c6dbb81d..fb47c8c1 100644 --- a/graphrag/api/prompt_tune.py +++ b/graphrag/api/prompt_tune.py @@ -52,7 +52,6 @@ from graphrag.prompt_tune.types import DocSelectionType async def generate_indexing_prompts( config: GraphRagConfig, logger: ProgressLogger, - root: str, chunk_size: PositiveInt = graphrag_config_defaults.chunks.size, overlap: Annotated[ int, annotated_types.Gt(-1) @@ -93,7 +92,6 @@ async def generate_indexing_prompts( # Retrieve documents logger.info("Chunking documents...") doc_list = await load_docs_in_chunks( - root=root, config=config, limit=limit, select_method=selection_method, diff --git a/graphrag/cli/index.py b/graphrag/cli/index.py index 0bb2bb49..29991e06 100644 --- a/graphrag/cli/index.py +++ b/graphrag/cli/index.py @@ -80,7 +80,6 @@ def index_cli( cli_overrides["reporting.base_dir"] = str(output_dir) cli_overrides["update_index_output.base_dir"] = str(output_dir) config = load_config(root_dir, config_filepath, cli_overrides) - _run_index( config=config, method=method, diff --git a/graphrag/cli/prompt_tune.py b/graphrag/cli/prompt_tune.py index fa07d0fb..b531b34a 100644 --- a/graphrag/cli/prompt_tune.py +++ b/graphrag/cli/prompt_tune.py @@ -86,7 +86,6 @@ async def prompt_tune( prompts = await api.generate_indexing_prompts( config=graph_config, - root=str(root_path), logger=progress_logger, chunk_size=chunk_size, overlap=overlap, diff --git a/graphrag/config/defaults.py b/graphrag/config/defaults.py index 0379f110..4cac43c4 100644 --- a/graphrag/config/defaults.py +++ b/graphrag/config/defaults.py @@ -14,11 +14,10 @@ from graphrag.config.enums import ( CacheType, ChunkStrategyType, InputFileType, - InputType, ModelType, NounPhraseExtractorType, - OutputType, ReportingType, + 
StorageType, ) from graphrag.index.operations.build_noun_graph.np_extractors.stop_words import ( EN_STOP_WORDS, @@ -234,16 +233,31 @@ class GlobalSearchDefaults: chat_model_id: str = DEFAULT_CHAT_MODEL_ID +@dataclass +class StorageDefaults: + """Default values for storage.""" + + type = StorageType.file + base_dir: str = DEFAULT_OUTPUT_BASE_DIR + connection_string: None = None + container_name: None = None + storage_account_blob_url: None = None + cosmosdb_account_url: None = None + + +@dataclass +class InputStorageDefaults(StorageDefaults): + """Default values for input storage.""" + + base_dir: str = "input" + + @dataclass class InputDefaults: """Default values for input.""" - type = InputType.file + storage: InputStorageDefaults = field(default_factory=InputStorageDefaults) file_type = InputFileType.text - base_dir: str = "input" - connection_string: None = None - storage_account_blob_url: None = None - container_name: None = None encoding: str = "utf-8" file_pattern: str = "" file_filter: None = None @@ -301,15 +315,10 @@ class LocalSearchDefaults: @dataclass -class OutputDefaults: +class OutputDefaults(StorageDefaults): """Default values for output.""" - type = OutputType.file base_dir: str = DEFAULT_OUTPUT_BASE_DIR - connection_string: None = None - container_name: None = None - storage_account_blob_url: None = None - cosmosdb_account_url: None = None @dataclass @@ -364,14 +373,10 @@ class UmapDefaults: @dataclass -class UpdateIndexOutputDefaults: +class UpdateIndexOutputDefaults(StorageDefaults): """Default values for update index output.""" - type = OutputType.file base_dir: str = "update_output" - connection_string: None = None - container_name: None = None - storage_account_blob_url: None = None @dataclass @@ -395,6 +400,7 @@ class GraphRagConfigDefaults: root_dir: str = "" models: dict = field(default_factory=dict) reporting: ReportingDefaults = field(default_factory=ReportingDefaults) + storage: StorageDefaults = field(default_factory=StorageDefaults) output: OutputDefaults = field(default_factory=OutputDefaults) outputs: None = None update_index_output: UpdateIndexOutputDefaults = field( diff --git a/graphrag/config/enums.py b/graphrag/config/enums.py index f3efdbd2..6c0c4ce1 100644 --- a/graphrag/config/enums.py +++ b/graphrag/config/enums.py @@ -42,20 +42,7 @@ class InputFileType(str, Enum): return f'"{self.value}"' -class InputType(str, Enum): - """The input type for the pipeline.""" - - file = "file" - """The file storage type.""" - blob = "blob" - """The blob storage type.""" - - def __repr__(self): - """Get a string representation.""" - return f'"{self.value}"' - - -class OutputType(str, Enum): +class StorageType(str, Enum): """The output type for the pipeline.""" file = "file" @@ -152,6 +139,10 @@ class IndexingMethod(str, Enum): """Traditional GraphRAG indexing, with all graph construction and summarization performed by a language model.""" Fast = "fast" """Fast indexing, using NLP for graph construction and language model for summarization.""" + StandardUpdate = "standard-update" + """Incremental update with standard indexing.""" + FastUpdate = "fast-update" + """Incremental update with fast indexing.""" class NounPhraseExtractorType(str, Enum): diff --git a/graphrag/config/init_content.py b/graphrag/config/init_content.py index 08559ffb..b58344a9 100644 --- a/graphrag/config/init_content.py +++ b/graphrag/config/init_content.py @@ -58,9 +58,11 @@ models: ### Input settings ### input: - type: {graphrag_config_defaults.input.type.value} # or blob + storage: + type: 
{graphrag_config_defaults.input.storage.type.value} # or blob + base_dir: "{graphrag_config_defaults.input.storage.base_dir}" file_type: {graphrag_config_defaults.input.file_type.value} # [csv, text, json] - base_dir: "{graphrag_config_defaults.input.base_dir}" + chunks: size: {graphrag_config_defaults.chunks.size} diff --git a/graphrag/config/models/graph_rag_config.py b/graphrag/config/models/graph_rag_config.py index c4c5b780..cac321f0 100644 --- a/graphrag/config/models/graph_rag_config.py +++ b/graphrag/config/models/graph_rag_config.py @@ -26,10 +26,10 @@ from graphrag.config.models.global_search_config import GlobalSearchConfig from graphrag.config.models.input_config import InputConfig from graphrag.config.models.language_model_config import LanguageModelConfig from graphrag.config.models.local_search_config import LocalSearchConfig -from graphrag.config.models.output_config import OutputConfig from graphrag.config.models.prune_graph_config import PruneGraphConfig from graphrag.config.models.reporting_config import ReportingConfig from graphrag.config.models.snapshots_config import SnapshotsConfig +from graphrag.config.models.storage_config import StorageConfig from graphrag.config.models.summarize_descriptions_config import ( SummarizeDescriptionsConfig, ) @@ -102,21 +102,31 @@ class GraphRagConfig(BaseModel): else: self.input.file_pattern = f".*\\.{self.input.file_type.value}$" + def _validate_input_base_dir(self) -> None: + """Validate the input base directory.""" + if self.input.storage.type == defs.StorageType.file: + if self.input.storage.base_dir.strip() == "": + msg = "input storage base directory is required for file input storage. Please rerun `graphrag init` and set the input storage configuration." + raise ValueError(msg) + self.input.storage.base_dir = str( + (Path(self.root_dir) / self.input.storage.base_dir).resolve() + ) + chunks: ChunkingConfig = Field( description="The chunking configuration to use.", default=ChunkingConfig(), ) """The chunking configuration to use.""" - output: OutputConfig = Field( + output: StorageConfig = Field( description="The output configuration.", - default=OutputConfig(), + default=StorageConfig(), ) """The output configuration.""" def _validate_output_base_dir(self) -> None: """Validate the output base directory.""" - if self.output.type == defs.OutputType.file: + if self.output.type == defs.StorageType.file: if self.output.base_dir.strip() == "": msg = "output base directory is required for file output. Please rerun `graphrag init` and set the output configuration." raise ValueError(msg) @@ -124,7 +134,7 @@ class GraphRagConfig(BaseModel): (Path(self.root_dir) / self.output.base_dir).resolve() ) - outputs: dict[str, OutputConfig] | None = Field( + outputs: dict[str, StorageConfig] | None = Field( description="A list of output configurations used for multi-index query.", default=graphrag_config_defaults.outputs, ) @@ -133,7 +143,7 @@ class GraphRagConfig(BaseModel): """Validate the outputs dict base directories.""" if self.outputs: for output in self.outputs.values(): - if output.type == defs.OutputType.file: + if output.type == defs.StorageType.file: if output.base_dir.strip() == "": msg = "Output base directory is required for file output. Please rerun `graphrag init` and set the output configuration." 
raise ValueError(msg) @@ -141,10 +151,9 @@ class GraphRagConfig(BaseModel): (Path(self.root_dir) / output.base_dir).resolve() ) - update_index_output: OutputConfig = Field( + update_index_output: StorageConfig = Field( description="The output configuration for the updated index.", - default=OutputConfig( - type=graphrag_config_defaults.update_index_output.type, + default=StorageConfig( base_dir=graphrag_config_defaults.update_index_output.base_dir, ), ) @@ -152,7 +161,7 @@ class GraphRagConfig(BaseModel): def _validate_update_index_output_base_dir(self) -> None: """Validate the update index output base directory.""" - if self.update_index_output.type == defs.OutputType.file: + if self.update_index_output.type == defs.StorageType.file: if self.update_index_output.base_dir.strip() == "": msg = "update_index_output base directory is required for file output. Please rerun `graphrag init` and set the update_index_output configuration." raise ValueError(msg) @@ -345,6 +354,7 @@ class GraphRagConfig(BaseModel): self._validate_root_dir() self._validate_models() self._validate_input_pattern() + self._validate_input_base_dir() self._validate_reporting_base_dir() self._validate_output_base_dir() self._validate_multi_output_base_dirs() diff --git a/graphrag/config/models/input_config.py b/graphrag/config/models/input_config.py index 526b6681..139dd90c 100644 --- a/graphrag/config/models/input_config.py +++ b/graphrag/config/models/input_config.py @@ -7,36 +7,23 @@ from pydantic import BaseModel, Field import graphrag.config.defaults as defs from graphrag.config.defaults import graphrag_config_defaults -from graphrag.config.enums import InputFileType, InputType +from graphrag.config.enums import InputFileType +from graphrag.config.models.storage_config import StorageConfig class InputConfig(BaseModel): """The default configuration section for Input.""" - type: InputType = Field( - description="The input type to use.", - default=graphrag_config_defaults.input.type, + storage: StorageConfig = Field( + description="The storage configuration to use for reading input documents.", + default=StorageConfig( + base_dir=graphrag_config_defaults.input.storage.base_dir, + ), ) file_type: InputFileType = Field( description="The input file type to use.", default=graphrag_config_defaults.input.file_type, ) - base_dir: str = Field( - description="The input base directory to use.", - default=graphrag_config_defaults.input.base_dir, - ) - connection_string: str | None = Field( - description="The azure blob storage connection string to use.", - default=graphrag_config_defaults.input.connection_string, - ) - storage_account_blob_url: str | None = Field( - description="The storage account blob url to use.", - default=graphrag_config_defaults.input.storage_account_blob_url, - ) - container_name: str | None = Field( - description="The azure blob storage container name to use.", - default=graphrag_config_defaults.input.container_name, - ) encoding: str = Field( description="The input file encoding to use.", default=defs.graphrag_config_defaults.input.encoding, diff --git a/graphrag/config/models/output_config.py b/graphrag/config/models/output_config.py deleted file mode 100644 index e38b137d..00000000 --- a/graphrag/config/models/output_config.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright (c) 2024 Microsoft Corporation. 
-# Licensed under the MIT License - -"""Parameterization settings for the default configuration.""" - -from pydantic import BaseModel, Field - -from graphrag.config.defaults import graphrag_config_defaults -from graphrag.config.enums import OutputType - - -class OutputConfig(BaseModel): - """The default configuration section for Output.""" - - type: OutputType = Field( - description="The output type to use.", - default=graphrag_config_defaults.output.type, - ) - base_dir: str = Field( - description="The base directory for the output.", - default=graphrag_config_defaults.output.base_dir, - ) - connection_string: str | None = Field( - description="The storage connection string to use.", - default=graphrag_config_defaults.output.connection_string, - ) - container_name: str | None = Field( - description="The storage container name to use.", - default=graphrag_config_defaults.output.container_name, - ) - storage_account_blob_url: str | None = Field( - description="The storage account blob url to use.", - default=graphrag_config_defaults.output.storage_account_blob_url, - ) - cosmosdb_account_url: str | None = Field( - description="The cosmosdb account url to use.", - default=graphrag_config_defaults.output.cosmosdb_account_url, - ) diff --git a/graphrag/config/models/storage_config.py b/graphrag/config/models/storage_config.py new file mode 100644 index 00000000..abd0936c --- /dev/null +++ b/graphrag/config/models/storage_config.py @@ -0,0 +1,52 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""Parameterization settings for the default configuration.""" + +from pathlib import Path + +from pydantic import BaseModel, Field, field_validator + +from graphrag.config.defaults import graphrag_config_defaults +from graphrag.config.enums import StorageType + + +class StorageConfig(BaseModel): + """The default configuration section for storage.""" + + type: StorageType = Field( + description="The storage type to use.", + default=graphrag_config_defaults.storage.type, + ) + base_dir: str = Field( + description="The base directory for the output.", + default=graphrag_config_defaults.storage.base_dir, + ) + + # Validate the base dir for multiple OS (use Path) + # if not using a cloud storage type. 
+ @field_validator("base_dir", mode="before") + @classmethod + def validate_base_dir(cls, value, info): + """Ensure that base_dir is a valid filesystem path when using local storage.""" + # info.data contains other field values, including 'type' + if info.data.get("type") != StorageType.file: + return value + return str(Path(value)) + + connection_string: str | None = Field( + description="The storage connection string to use.", + default=graphrag_config_defaults.storage.connection_string, + ) + container_name: str | None = Field( + description="The storage container name to use.", + default=graphrag_config_defaults.storage.container_name, + ) + storage_account_blob_url: str | None = Field( + description="The storage account blob url to use.", + default=graphrag_config_defaults.storage.storage_account_blob_url, + ) + cosmosdb_account_url: str | None = Field( + description="The cosmosdb account url to use.", + default=graphrag_config_defaults.storage.cosmosdb_account_url, + ) diff --git a/graphrag/index/input/csv.py b/graphrag/index/input/csv.py index a178c419..bcd42fec 100644 --- a/graphrag/index/input/csv.py +++ b/graphrag/index/input/csv.py @@ -22,7 +22,7 @@ async def load_csv( storage: PipelineStorage, ) -> pd.DataFrame: """Load csv inputs from a directory.""" - log.info("Loading csv files from %s", config.base_dir) + log.info("Loading csv files from %s", config.storage.base_dir) async def load_file(path: str, group: dict | None) -> pd.DataFrame: if group is None: diff --git a/graphrag/index/input/factory.py b/graphrag/index/input/factory.py index 7c1d54bb..096204b0 100644 --- a/graphrag/index/input/factory.py +++ b/graphrag/index/input/factory.py @@ -5,20 +5,18 @@ import logging from collections.abc import Awaitable, Callable -from pathlib import Path from typing import cast import pandas as pd -from graphrag.config.enums import InputFileType, InputType +from graphrag.config.enums import InputFileType from graphrag.config.models.input_config import InputConfig from graphrag.index.input.csv import load_csv from graphrag.index.input.json import load_json from graphrag.index.input.text import load_text from graphrag.logger.base import ProgressLogger from graphrag.logger.null_progress import NullProgressLogger -from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage -from graphrag.storage.file_pipeline_storage import FilePipelineStorage +from graphrag.storage.pipeline_storage import PipelineStorage log = logging.getLogger(__name__) loaders: dict[str, Callable[..., Awaitable[pd.DataFrame]]] = { @@ -30,43 +28,12 @@ loaders: dict[str, Callable[..., Awaitable[pd.DataFrame]]] = { async def create_input( config: InputConfig, + storage: PipelineStorage, progress_reporter: ProgressLogger | None = None, - root_dir: str | None = None, ) -> pd.DataFrame: """Instantiate input data for a pipeline.""" - root_dir = root_dir or "" - log.info("loading input from root_dir=%s", config.base_dir) progress_reporter = progress_reporter or NullProgressLogger() - match config.type: - case InputType.blob: - log.info("using blob storage input") - if config.container_name is None: - msg = "Container name required for blob storage" - raise ValueError(msg) - if ( - config.connection_string is None - and config.storage_account_blob_url is None - ): - msg = "Connection string or storage account blob url required for blob storage" - raise ValueError(msg) - storage = BlobPipelineStorage( - connection_string=config.connection_string, - storage_account_blob_url=config.storage_account_blob_url, - 
container_name=config.container_name, - path_prefix=config.base_dir, - ) - case InputType.file: - log.info("using file storage for input") - storage = FilePipelineStorage( - root_dir=str(Path(root_dir) / (config.base_dir or "")) - ) - case _: - log.info("using file storage for input") - storage = FilePipelineStorage( - root_dir=str(Path(root_dir) / (config.base_dir or "")) - ) - if config.file_type in loaders: progress = progress_reporter.child( f"Loading Input ({config.file_type})", transient=False diff --git a/graphrag/index/input/json.py b/graphrag/index/input/json.py index fed19d31..df5b09cb 100644 --- a/graphrag/index/input/json.py +++ b/graphrag/index/input/json.py @@ -22,7 +22,7 @@ async def load_json( storage: PipelineStorage, ) -> pd.DataFrame: """Load json inputs from a directory.""" - log.info("Loading json files from %s", config.base_dir) + log.info("Loading json files from %s", config.storage.base_dir) async def load_file(path: str, group: dict | None) -> pd.DataFrame: if group is None: diff --git a/graphrag/index/input/util.py b/graphrag/index/input/util.py index b8ab4550..7335e9e8 100644 --- a/graphrag/index/input/util.py +++ b/graphrag/index/input/util.py @@ -33,7 +33,7 @@ async def load_files( ) if len(files) == 0: - msg = f"No {config.file_type} files found in {config.base_dir}" + msg = f"No {config.file_type} files found in {config.storage.base_dir}" raise ValueError(msg) files_loaded = [] diff --git a/graphrag/index/run/run_pipeline.py b/graphrag/index/run/run_pipeline.py index 65c41b7e..aea402cc 100644 --- a/graphrag/index/run/run_pipeline.py +++ b/graphrag/index/run/run_pipeline.py @@ -11,16 +11,12 @@ import traceback from collections.abc import AsyncIterable from dataclasses import asdict -import pandas as pd - from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.config.models.graph_rag_config import GraphRagConfig -from graphrag.index.input.factory import create_input from graphrag.index.run.utils import create_run_context from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.pipeline import Pipeline from graphrag.index.typing.pipeline_run_result import PipelineRunResult -from graphrag.index.update.incremental_index import get_delta_docs from graphrag.logger.base import ProgressLogger from graphrag.logger.progress import Progress from graphrag.storage.pipeline_storage import PipelineStorage @@ -40,86 +36,72 @@ async def run_pipeline( """Run all workflows using a simplified pipeline.""" root_dir = config.root_dir - storage = create_storage_from_config(config.output) + input_storage = create_storage_from_config(config.input.storage) + output_storage = create_storage_from_config(config.output) cache = create_cache_from_config(config.cache, root_dir) - dataset = await create_input(config.input, logger, root_dir) - # load existing state in case any workflows are stateful - state_json = await storage.get("context.json") + state_json = await output_storage.get("context.json") state = json.loads(state_json) if state_json else {} if is_update_run: logger.info("Running incremental indexing.") - delta_dataset = await get_delta_docs(dataset, storage) + update_storage = create_storage_from_config(config.update_index_output) + # we use this to store the new subset index, and will merge its content with the previous index + update_timestamp = time.strftime("%Y%m%d-%H%M%S") + timestamped_storage = update_storage.child(update_timestamp) + delta_storage = timestamped_storage.child("delta") + # copy the previous output to 
a backup folder, so we can replace it with the update + # we'll read from this later when we merge the old and new indexes + previous_storage = timestamped_storage.child("previous") + await _copy_previous_output(output_storage, previous_storage) - # warn on empty delta dataset - if delta_dataset.new_inputs.empty: - warning_msg = "Incremental indexing found no new documents, exiting." - logger.warning(warning_msg) - else: - update_storage = create_storage_from_config(config.update_index_output) - # we use this to store the new subset index, and will merge its content with the previous index - update_timestamp = time.strftime("%Y%m%d-%H%M%S") - timestamped_storage = update_storage.child(update_timestamp) - delta_storage = timestamped_storage.child("delta") - # copy the previous output to a backup folder, so we can replace it with the update - # we'll read from this later when we merge the old and new indexes - previous_storage = timestamped_storage.child("previous") - await _copy_previous_output(storage, previous_storage) + state["update_timestamp"] = update_timestamp - state["update_timestamp"] = update_timestamp - - context = create_run_context( - storage=delta_storage, cache=cache, callbacks=callbacks, state=state - ) - - # Run the pipeline on the new documents - async for table in _run_pipeline( - pipeline=pipeline, - config=config, - dataset=delta_dataset.new_inputs, - logger=logger, - context=context, - ): - yield table - - logger.success("Finished running workflows on new documents.") + context = create_run_context( + input_storage=input_storage, + output_storage=delta_storage, + previous_storage=previous_storage, + cache=cache, + callbacks=callbacks, + state=state, + progress_logger=logger, + ) else: logger.info("Running standard indexing.") context = create_run_context( - storage=storage, cache=cache, callbacks=callbacks, state=state + input_storage=input_storage, + output_storage=output_storage, + cache=cache, + callbacks=callbacks, + state=state, + progress_logger=logger, ) - async for table in _run_pipeline( - pipeline=pipeline, - config=config, - dataset=dataset, - logger=logger, - context=context, - ): - yield table + async for table in _run_pipeline( + pipeline=pipeline, + config=config, + logger=logger, + context=context, + ): + yield table async def _run_pipeline( pipeline: Pipeline, config: GraphRagConfig, - dataset: pd.DataFrame, logger: ProgressLogger, context: PipelineRunContext, ) -> AsyncIterable[PipelineRunResult]: start_time = time.time() - log.info("Final # of rows loaded: %s", len(dataset)) - context.stats.num_documents = len(dataset) - last_workflow = "starting documents" + last_workflow = "" try: await _dump_json(context) - await write_table_to_storage(dataset, "documents", context.storage) for name, workflow_function in pipeline.run(): last_workflow = name @@ -132,8 +114,10 @@ async def _run_pipeline( yield PipelineRunResult( workflow=name, result=result.result, state=context.state, errors=None ) - context.stats.workflows[name] = {"overall": time.time() - work_time} + if result.stop: + logger.info("Halting pipeline at workflow request") + break context.stats.total_runtime = time.time() - start_time await _dump_json(context) @@ -148,10 +132,10 @@ async def _run_pipeline( async def _dump_json(context: PipelineRunContext) -> None: """Dump the stats and context state to the storage.""" - await context.storage.set( + await context.output_storage.set( "stats.json", json.dumps(asdict(context.stats), indent=4, ensure_ascii=False) ) - await context.storage.set( + await 
context.output_storage.set( "context.json", json.dumps(context.state, indent=4, ensure_ascii=False) ) diff --git a/graphrag/index/run/utils.py b/graphrag/index/run/utils.py index a5c23074..809e229b 100644 --- a/graphrag/index/run/utils.py +++ b/graphrag/index/run/utils.py @@ -14,24 +14,31 @@ from graphrag.index.typing.context import PipelineRunContext from graphrag.index.typing.state import PipelineState from graphrag.index.typing.stats import PipelineRunStats from graphrag.logger.base import ProgressLogger +from graphrag.logger.null_progress import NullProgressLogger from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage from graphrag.storage.pipeline_storage import PipelineStorage from graphrag.utils.api import create_storage_from_config def create_run_context( - storage: PipelineStorage | None = None, + input_storage: PipelineStorage | None = None, + output_storage: PipelineStorage | None = None, + previous_storage: PipelineStorage | None = None, cache: PipelineCache | None = None, callbacks: WorkflowCallbacks | None = None, + progress_logger: ProgressLogger | None = None, stats: PipelineRunStats | None = None, state: PipelineState | None = None, ) -> PipelineRunContext: """Create the run context for the pipeline.""" return PipelineRunContext( - stats=stats or PipelineRunStats(), + input_storage=input_storage or MemoryPipelineStorage(), + output_storage=output_storage or MemoryPipelineStorage(), + previous_storage=previous_storage or MemoryPipelineStorage(), cache=cache or InMemoryCache(), - storage=storage or MemoryPipelineStorage(), callbacks=callbacks or NoopWorkflowCallbacks(), + progress_logger=progress_logger or NullProgressLogger(), + stats=stats or PipelineRunStats(), state=state or {}, ) diff --git a/graphrag/index/typing/context.py b/graphrag/index/typing/context.py index 95a9a501..9ac0a4e3 100644 --- a/graphrag/index/typing/context.py +++ b/graphrag/index/typing/context.py @@ -10,6 +10,7 @@ from graphrag.cache.pipeline_cache import PipelineCache from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks from graphrag.index.typing.state import PipelineState from graphrag.index.typing.stats import PipelineRunStats +from graphrag.logger.base import ProgressLogger from graphrag.storage.pipeline_storage import PipelineStorage @@ -18,11 +19,17 @@ class PipelineRunContext: """Provides the context for the current pipeline run.""" stats: PipelineRunStats - storage: PipelineStorage + input_storage: PipelineStorage + "Storage for input documents." + output_storage: PipelineStorage "Long-term storage for pipeline verbs to use. Items written here will be written to the storage provider." + previous_storage: PipelineStorage + "Storage for previous pipeline run when running in update mode." cache: PipelineCache "Cache instance for reading previous LLM responses." callbacks: WorkflowCallbacks "Callbacks to be called during the pipeline run." + progress_logger: ProgressLogger + "Progress logger for the pipeline run." state: PipelineState "Arbitrary property bag for runtime state, persistent pre-computes, or experimental features." 
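Taken together, the refactored PipelineRunContext above and the registration hooks elsewhere in this change (register_workflow_function in graphrag/api/index.py and PipelineFactory.register_pipeline in graphrag/index/workflows/factory.py) are what enable the custom pipeline injection this PR describes. The sketch below is one possible wiring and is not part of the diff: the workflow name "snapshot_documents" and the pipeline name "custom" are hypothetical, and a real GraphRagConfig would still be loaded through the usual settings path before calling build_index(config, method="custom").

from graphrag.api.index import register_workflow_function
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.index.workflows.factory import PipelineFactory
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage


async def snapshot_documents(
    config: GraphRagConfig, context: PipelineRunContext
) -> WorkflowFunctionOutput:
    """Hypothetical custom workflow: copy the loaded documents table to a snapshot."""
    documents = await load_table_from_storage("documents", context.output_storage)
    await write_table_to_storage(
        documents, "documents_snapshot", context.output_storage
    )
    # Returning stop=True here would halt the pipeline after this workflow.
    return WorkflowFunctionOutput(result=len(documents))


# Expose the workflow by name, then compose it into a named pipeline that
# build_index(config, method="custom") resolves through PipelineFactory.
register_workflow_function("snapshot_documents", snapshot_documents)
PipelineFactory.register_pipeline(
    "custom",
    ["load_input_documents", "create_base_text_units", "snapshot_documents"],
)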
diff --git a/graphrag/index/typing/stats.py b/graphrag/index/typing/stats.py index bcdaffa1..27177360 100644 --- a/graphrag/index/typing/stats.py +++ b/graphrag/index/typing/stats.py @@ -15,6 +15,8 @@ class PipelineRunStats: num_documents: int = field(default=0) """Number of documents.""" + update_documents: int = field(default=0) + """Number of update documents.""" input_load_time: float = field(default=0) """Float representing the input load time.""" diff --git a/graphrag/index/typing/workflow.py b/graphrag/index/typing/workflow.py index 1d18b60a..89538b04 100644 --- a/graphrag/index/typing/workflow.py +++ b/graphrag/index/typing/workflow.py @@ -17,6 +17,8 @@ class WorkflowFunctionOutput: result: Any | None """The result of the workflow function. This can be anything - we use it only for logging downstream, and expect each workflow function to write official outputs to the provided storage.""" + stop: bool = False + """Flag to indicate if the workflow should stop after this function. This should only be used when continuation could cause an unstable failure.""" WorkflowFunction = Callable[ diff --git a/graphrag/index/workflows/__init__.py b/graphrag/index/workflows/__init__.py index 7f38a3cc..5567bace 100644 --- a/graphrag/index/workflows/__init__.py +++ b/graphrag/index/workflows/__init__.py @@ -39,6 +39,12 @@ from .finalize_graph import ( from .generate_text_embeddings import ( run_workflow as run_generate_text_embeddings, ) +from .load_input_documents import ( + run_workflow as run_load_input_documents, +) +from .load_update_documents import ( + run_workflow as run_load_update_documents, +) from .prune_graph import ( run_workflow as run_prune_graph, ) @@ -69,6 +75,8 @@ from .update_text_units import ( # register all of our built-in workflows at once PipelineFactory.register_all({ + "load_input_documents": run_load_input_documents, + "load_update_documents": run_load_update_documents, "create_base_text_units": run_create_base_text_units, "create_communities": run_create_communities, "create_community_reports_text": run_create_community_reports_text, diff --git a/graphrag/index/workflows/create_base_text_units.py b/graphrag/index/workflows/create_base_text_units.py index 30d4f2a7..d25c9a9f 100644 --- a/graphrag/index/workflows/create_base_text_units.py +++ b/graphrag/index/workflows/create_base_text_units.py @@ -25,7 +25,7 @@ async def run_workflow( context: PipelineRunContext, ) -> WorkflowFunctionOutput: """All the steps to transform base text_units.""" - documents = await load_table_from_storage("documents", context.storage) + documents = await load_table_from_storage("documents", context.output_storage) chunks = config.chunks @@ -41,7 +41,7 @@ async def run_workflow( chunk_size_includes_metadata=chunks.chunk_size_includes_metadata, ) - await write_table_to_storage(output, "text_units", context.storage) + await write_table_to_storage(output, "text_units", context.output_storage) return WorkflowFunctionOutput(result=output) diff --git a/graphrag/index/workflows/create_communities.py b/graphrag/index/workflows/create_communities.py index b19eb181..b1096f6a 100644 --- a/graphrag/index/workflows/create_communities.py +++ b/graphrag/index/workflows/create_communities.py @@ -24,8 +24,10 @@ async def run_workflow( context: PipelineRunContext, ) -> WorkflowFunctionOutput: """All the steps to transform final communities.""" - entities = await load_table_from_storage("entities", context.storage) - relationships = await load_table_from_storage("relationships", context.storage) + entities = 
await load_table_from_storage("entities", context.output_storage) + relationships = await load_table_from_storage( + "relationships", context.output_storage + ) max_cluster_size = config.cluster_graph.max_cluster_size use_lcc = config.cluster_graph.use_lcc @@ -39,7 +41,7 @@ async def run_workflow( seed=seed, ) - await write_table_to_storage(output, "communities", context.storage) + await write_table_to_storage(output, "communities", context.output_storage) return WorkflowFunctionOutput(result=output) diff --git a/graphrag/index/workflows/create_community_reports.py b/graphrag/index/workflows/create_community_reports.py index 6f20639c..15dc187c 100644 --- a/graphrag/index/workflows/create_community_reports.py +++ b/graphrag/index/workflows/create_community_reports.py @@ -38,14 +38,14 @@ async def run_workflow( context: PipelineRunContext, ) -> WorkflowFunctionOutput: """All the steps to transform community reports.""" - edges = await load_table_from_storage("relationships", context.storage) - entities = await load_table_from_storage("entities", context.storage) - communities = await load_table_from_storage("communities", context.storage) + edges = await load_table_from_storage("relationships", context.output_storage) + entities = await load_table_from_storage("entities", context.output_storage) + communities = await load_table_from_storage("communities", context.output_storage) claims = None if config.extract_claims.enabled and await storage_has_table( - "covariates", context.storage + "covariates", context.output_storage ): - claims = await load_table_from_storage("covariates", context.storage) + claims = await load_table_from_storage("covariates", context.output_storage) community_reports_llm_settings = config.get_language_model_config( config.community_reports.model_id @@ -68,7 +68,7 @@ async def run_workflow( num_threads=num_threads, ) - await write_table_to_storage(output, "community_reports", context.storage) + await write_table_to_storage(output, "community_reports", context.output_storage) return WorkflowFunctionOutput(result=output) diff --git a/graphrag/index/workflows/create_community_reports_text.py b/graphrag/index/workflows/create_community_reports_text.py index c713a6c5..b584c6f5 100644 --- a/graphrag/index/workflows/create_community_reports_text.py +++ b/graphrag/index/workflows/create_community_reports_text.py @@ -37,10 +37,10 @@ async def run_workflow( context: PipelineRunContext, ) -> WorkflowFunctionOutput: """All the steps to transform community reports.""" - entities = await load_table_from_storage("entities", context.storage) - communities = await load_table_from_storage("communities", context.storage) + entities = await load_table_from_storage("entities", context.output_storage) + communities = await load_table_from_storage("communities", context.output_storage) - text_units = await load_table_from_storage("text_units", context.storage) + text_units = await load_table_from_storage("text_units", context.output_storage) community_reports_llm_settings = config.get_language_model_config( config.community_reports.model_id @@ -62,7 +62,7 @@ async def run_workflow( num_threads=num_threads, ) - await write_table_to_storage(output, "community_reports", context.storage) + await write_table_to_storage(output, "community_reports", context.output_storage) return WorkflowFunctionOutput(result=output) diff --git a/graphrag/index/workflows/create_final_documents.py b/graphrag/index/workflows/create_final_documents.py index 16ab4532..0f007660 100644 --- 
a/graphrag/index/workflows/create_final_documents.py +++ b/graphrag/index/workflows/create_final_documents.py @@ -17,12 +17,12 @@ async def run_workflow( context: PipelineRunContext, ) -> WorkflowFunctionOutput: """All the steps to transform final documents.""" - documents = await load_table_from_storage("documents", context.storage) - text_units = await load_table_from_storage("text_units", context.storage) + documents = await load_table_from_storage("documents", context.output_storage) + text_units = await load_table_from_storage("text_units", context.output_storage) output = create_final_documents(documents, text_units) - await write_table_to_storage(output, "documents", context.storage) + await write_table_to_storage(output, "documents", context.output_storage) return WorkflowFunctionOutput(result=output) diff --git a/graphrag/index/workflows/create_final_text_units.py b/graphrag/index/workflows/create_final_text_units.py index e28fd1ce..f1ae44ed 100644 --- a/graphrag/index/workflows/create_final_text_units.py +++ b/graphrag/index/workflows/create_final_text_units.py @@ -21,16 +21,18 @@ async def run_workflow( context: PipelineRunContext, ) -> WorkflowFunctionOutput: """All the steps to transform the text units.""" - text_units = await load_table_from_storage("text_units", context.storage) - final_entities = await load_table_from_storage("entities", context.storage) + text_units = await load_table_from_storage("text_units", context.output_storage) + final_entities = await load_table_from_storage("entities", context.output_storage) final_relationships = await load_table_from_storage( - "relationships", context.storage + "relationships", context.output_storage ) final_covariates = None if config.extract_claims.enabled and await storage_has_table( - "covariates", context.storage + "covariates", context.output_storage ): - final_covariates = await load_table_from_storage("covariates", context.storage) + final_covariates = await load_table_from_storage( + "covariates", context.output_storage + ) output = create_final_text_units( text_units, @@ -39,7 +41,7 @@ async def run_workflow( final_covariates, ) - await write_table_to_storage(output, "text_units", context.storage) + await write_table_to_storage(output, "text_units", context.output_storage) return WorkflowFunctionOutput(result=output) diff --git a/graphrag/index/workflows/extract_covariates.py b/graphrag/index/workflows/extract_covariates.py index b4124301..c6766850 100644 --- a/graphrag/index/workflows/extract_covariates.py +++ b/graphrag/index/workflows/extract_covariates.py @@ -26,30 +26,32 @@ async def run_workflow( context: PipelineRunContext, ) -> WorkflowFunctionOutput: """All the steps to extract and format covariates.""" - text_units = await load_table_from_storage("text_units", context.storage) + output = None + if config.extract_claims.enabled: + text_units = await load_table_from_storage("text_units", context.output_storage) - extract_claims_llm_settings = config.get_language_model_config( - config.extract_claims.model_id - ) - extraction_strategy = config.extract_claims.resolved_strategy( - config.root_dir, extract_claims_llm_settings - ) + extract_claims_llm_settings = config.get_language_model_config( + config.extract_claims.model_id + ) + extraction_strategy = config.extract_claims.resolved_strategy( + config.root_dir, extract_claims_llm_settings + ) - async_mode = extract_claims_llm_settings.async_mode - num_threads = extract_claims_llm_settings.concurrent_requests + async_mode = extract_claims_llm_settings.async_mode 
+ num_threads = extract_claims_llm_settings.concurrent_requests - output = await extract_covariates( - text_units, - context.callbacks, - context.cache, - "claim", - extraction_strategy, - async_mode=async_mode, - entity_types=None, - num_threads=num_threads, - ) + output = await extract_covariates( + text_units, + context.callbacks, + context.cache, + "claim", + extraction_strategy, + async_mode=async_mode, + entity_types=None, + num_threads=num_threads, + ) - await write_table_to_storage(output, "covariates", context.storage) + await write_table_to_storage(output, "covariates", context.output_storage) return WorkflowFunctionOutput(result=output) diff --git a/graphrag/index/workflows/extract_graph.py b/graphrag/index/workflows/extract_graph.py index 84f8647e..c6d259de 100644 --- a/graphrag/index/workflows/extract_graph.py +++ b/graphrag/index/workflows/extract_graph.py @@ -27,7 +27,7 @@ async def run_workflow( context: PipelineRunContext, ) -> WorkflowFunctionOutput: """All the steps to create the base entity graph.""" - text_units = await load_table_from_storage("text_units", context.storage) + text_units = await load_table_from_storage("text_units", context.output_storage) extract_graph_llm_settings = config.get_language_model_config( config.extract_graph.model_id @@ -55,13 +55,15 @@ async def run_workflow( summarization_num_threads=summarization_llm_settings.concurrent_requests, ) - await write_table_to_storage(entities, "entities", context.storage) - await write_table_to_storage(relationships, "relationships", context.storage) + await write_table_to_storage(entities, "entities", context.output_storage) + await write_table_to_storage(relationships, "relationships", context.output_storage) if config.snapshots.raw_graph: - await write_table_to_storage(raw_entities, "raw_entities", context.storage) await write_table_to_storage( - raw_relationships, "raw_relationships", context.storage + raw_entities, "raw_entities", context.output_storage + ) + await write_table_to_storage( + raw_relationships, "raw_relationships", context.output_storage ) return WorkflowFunctionOutput( diff --git a/graphrag/index/workflows/extract_graph_nlp.py b/graphrag/index/workflows/extract_graph_nlp.py index 397c00f7..5c1d0972 100644 --- a/graphrag/index/workflows/extract_graph_nlp.py +++ b/graphrag/index/workflows/extract_graph_nlp.py @@ -22,7 +22,7 @@ async def run_workflow( context: PipelineRunContext, ) -> WorkflowFunctionOutput: """All the steps to create the base entity graph.""" - text_units = await load_table_from_storage("text_units", context.storage) + text_units = await load_table_from_storage("text_units", context.output_storage) entities, relationships = await extract_graph_nlp( text_units, @@ -30,8 +30,8 @@ async def run_workflow( extraction_config=config.extract_graph_nlp, ) - await write_table_to_storage(entities, "entities", context.storage) - await write_table_to_storage(relationships, "relationships", context.storage) + await write_table_to_storage(entities, "entities", context.output_storage) + await write_table_to_storage(relationships, "relationships", context.output_storage) return WorkflowFunctionOutput( result={ diff --git a/graphrag/index/workflows/factory.py b/graphrag/index/workflows/factory.py index c73e64b6..6e86540d 100644 --- a/graphrag/index/workflows/factory.py +++ b/graphrag/index/workflows/factory.py @@ -15,6 +15,7 @@ class PipelineFactory: """A factory class for workflow pipelines.""" workflows: ClassVar[dict[str, WorkflowFunction]] = {} + pipelines: ClassVar[dict[str, 
list[str]]] = {} @classmethod def register(cls, name: str, workflow: WorkflowFunction): @@ -27,61 +28,66 @@ class PipelineFactory: for name, workflow in workflows.items(): cls.register(name, workflow) + @classmethod + def register_pipeline(cls, name: str, workflows: list[str]): + """Register a new pipeline method as a list of workflow names.""" + cls.pipelines[name] = workflows + @classmethod def create_pipeline( cls, config: GraphRagConfig, - method: IndexingMethod = IndexingMethod.Standard, - is_update_run: bool = False, + method: IndexingMethod | str = IndexingMethod.Standard, ) -> Pipeline: """Create a pipeline generator.""" - workflows = _get_workflows_list(config, method, is_update_run) + workflows = config.workflows or cls.pipelines.get(method, []) return Pipeline([(name, cls.workflows[name]) for name in workflows]) -def _get_workflows_list( - config: GraphRagConfig, - method: IndexingMethod = IndexingMethod.Standard, - is_update_run: bool = False, -) -> list[str]: - """Return a list of workflows for the indexing pipeline.""" - update_workflows = [ - "update_final_documents", - "update_entities_relationships", - "update_text_units", - "update_covariates", - "update_communities", - "update_community_reports", - "update_text_embeddings", - "update_clean_state", - ] - if config.workflows: - return config.workflows - - match method: - case IndexingMethod.Standard: - return [ - "create_base_text_units", - "create_final_documents", - "extract_graph", - "finalize_graph", - *(["extract_covariates"] if config.extract_claims.enabled else []), - "create_communities", - "create_final_text_units", - "create_community_reports", - "generate_text_embeddings", - *(update_workflows if is_update_run else []), - ] - case IndexingMethod.Fast: - return [ - "create_base_text_units", - "create_final_documents", - "extract_graph_nlp", - "prune_graph", - "finalize_graph", - "create_communities", - "create_final_text_units", - "create_community_reports_text", - "generate_text_embeddings", - *(update_workflows if is_update_run else []), - ] +# --- Register default implementations --- +_standard_workflows = [ + "create_base_text_units", + "create_final_documents", + "extract_graph", + "finalize_graph", + "extract_covariates", + "create_communities", + "create_final_text_units", + "create_community_reports", + "generate_text_embeddings", +] +_fast_workflows = [ + "create_base_text_units", + "create_final_documents", + "extract_graph_nlp", + "prune_graph", + "finalize_graph", + "create_communities", + "create_final_text_units", + "create_community_reports_text", + "generate_text_embeddings", +] +_update_workflows = [ + "update_final_documents", + "update_entities_relationships", + "update_text_units", + "update_covariates", + "update_communities", + "update_community_reports", + "update_text_embeddings", + "update_clean_state", +] +PipelineFactory.register_pipeline( + IndexingMethod.Standard, ["load_input_documents", *_standard_workflows] +) +PipelineFactory.register_pipeline( + IndexingMethod.Fast, ["load_input_documents", *_fast_workflows] +) +PipelineFactory.register_pipeline( + IndexingMethod.StandardUpdate, + ["load_update_documents", *_standard_workflows, *_update_workflows], +) +PipelineFactory.register_pipeline( + IndexingMethod.FastUpdate, + ["load_update_documents", *_fast_workflows, *_update_workflows], +) diff --git a/graphrag/index/workflows/finalize_graph.py b/graphrag/index/workflows/finalize_graph.py index a5a94ba0..8e79bb39 100644 --- a/graphrag/index/workflows/finalize_graph.py +++ 
b/graphrag/index/workflows/finalize_graph.py @@ -22,8 +22,10 @@ async def run_workflow( context: PipelineRunContext, ) -> WorkflowFunctionOutput: """All the steps to create the base entity graph.""" - entities = await load_table_from_storage("entities", context.storage) - relationships = await load_table_from_storage("relationships", context.storage) + entities = await load_table_from_storage("entities", context.output_storage) + relationships = await load_table_from_storage( + "relationships", context.output_storage + ) final_entities, final_relationships = finalize_graph( entities, @@ -33,8 +35,10 @@ async def run_workflow( layout_enabled=config.umap.enabled, ) - await write_table_to_storage(final_entities, "entities", context.storage) - await write_table_to_storage(final_relationships, "relationships", context.storage) + await write_table_to_storage(final_entities, "entities", context.output_storage) + await write_table_to_storage( + final_relationships, "relationships", context.output_storage + ) if config.snapshots.graphml: # todo: extract graphs at each level, and add in meta like descriptions @@ -43,7 +47,7 @@ async def run_workflow( await snapshot_graphml( graph, name="graph", - storage=context.storage, + storage=context.output_storage, ) return WorkflowFunctionOutput( diff --git a/graphrag/index/workflows/generate_text_embeddings.py b/graphrag/index/workflows/generate_text_embeddings.py index f10fcbdb..c19fed96 100644 --- a/graphrag/index/workflows/generate_text_embeddings.py +++ b/graphrag/index/workflows/generate_text_embeddings.py @@ -43,17 +43,19 @@ async def run_workflow( text_units = None entities = None community_reports = None - if await storage_has_table("documents", context.storage): - documents = await load_table_from_storage("documents", context.storage) - if await storage_has_table("relationships", context.storage): - relationships = await load_table_from_storage("relationships", context.storage) - if await storage_has_table("text_units", context.storage): - text_units = await load_table_from_storage("text_units", context.storage) - if await storage_has_table("entities", context.storage): - entities = await load_table_from_storage("entities", context.storage) - if await storage_has_table("community_reports", context.storage): + if await storage_has_table("documents", context.output_storage): + documents = await load_table_from_storage("documents", context.output_storage) + if await storage_has_table("relationships", context.output_storage): + relationships = await load_table_from_storage( + "relationships", context.output_storage + ) + if await storage_has_table("text_units", context.output_storage): + text_units = await load_table_from_storage("text_units", context.output_storage) + if await storage_has_table("entities", context.output_storage): + entities = await load_table_from_storage("entities", context.output_storage) + if await storage_has_table("community_reports", context.output_storage): community_reports = await load_table_from_storage( - "community_reports", context.storage + "community_reports", context.output_storage ) embedded_fields = config.embed_text.names @@ -76,7 +78,7 @@ async def run_workflow( await write_table_to_storage( table, f"embeddings.{name}", - context.storage, + context.output_storage, ) return WorkflowFunctionOutput(result=output) diff --git a/graphrag/index/workflows/load_input_documents.py b/graphrag/index/workflows/load_input_documents.py new file mode 100644 index 00000000..b3d41703 --- /dev/null +++ 
b/graphrag/index/workflows/load_input_documents.py @@ -0,0 +1,45 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License + +"""A module containing run_workflow method definition.""" + +import logging + +import pandas as pd + +from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.config.models.input_config import InputConfig +from graphrag.index.input.factory import create_input +from graphrag.index.typing.context import PipelineRunContext +from graphrag.index.typing.workflow import WorkflowFunctionOutput +from graphrag.logger.base import ProgressLogger +from graphrag.storage.pipeline_storage import PipelineStorage +from graphrag.utils.storage import write_table_to_storage + +log = logging.getLogger(__name__) + + +async def run_workflow( + config: GraphRagConfig, + context: PipelineRunContext, +) -> WorkflowFunctionOutput: + """Load and parse input documents into a standard format.""" + output = await load_input_documents( + config.input, + context.input_storage, + context.progress_logger, + ) + + log.info("Final # of rows loaded: %s", len(output)) + context.stats.num_documents = len(output) + + await write_table_to_storage(output, "documents", context.output_storage) + + return WorkflowFunctionOutput(result=output) + + +async def load_input_documents( + config: InputConfig, storage: PipelineStorage, progress_logger: ProgressLogger +) -> pd.DataFrame: + """Load and parse input documents into a standard format.""" + return await create_input(config, storage, progress_logger) diff --git a/graphrag/index/workflows/load_update_documents.py b/graphrag/index/workflows/load_update_documents.py new file mode 100644 index 00000000..18d1e1b9 --- /dev/null +++ b/graphrag/index/workflows/load_update_documents.py @@ -0,0 +1,59 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A module containing run_workflow method definition.""" + +import logging + +import pandas as pd + +from graphrag.config.models.graph_rag_config import GraphRagConfig +from graphrag.config.models.input_config import InputConfig +from graphrag.index.input.factory import create_input +from graphrag.index.typing.context import PipelineRunContext +from graphrag.index.typing.workflow import WorkflowFunctionOutput +from graphrag.index.update.incremental_index import get_delta_docs +from graphrag.logger.base import ProgressLogger +from graphrag.storage.pipeline_storage import PipelineStorage +from graphrag.utils.storage import write_table_to_storage + +log = logging.getLogger(__name__) + + +async def run_workflow( + config: GraphRagConfig, + context: PipelineRunContext, +) -> WorkflowFunctionOutput: + """Load and parse update-only input documents into a standard format.""" + output = await load_update_documents( + config.input, + context.input_storage, + context.previous_storage, + context.progress_logger, + ) + + log.info("Final # of update rows loaded: %s", len(output)) + context.stats.update_documents = len(output) + + if len(output) == 0: + log.warning("No new update documents found.") + context.progress_logger.warning("No new update documents found.") + return WorkflowFunctionOutput(result=None, stop=True) + + await write_table_to_storage(output, "documents", context.output_storage) + + return WorkflowFunctionOutput(result=output) + + +async def load_update_documents( + config: InputConfig, + input_storage: PipelineStorage, + previous_storage: PipelineStorage, + progress_logger: ProgressLogger, +) -> pd.DataFrame: + """Load and parse update-only input documents into a standard format.""" + input_documents = await create_input(config, input_storage, progress_logger) + # previous storage is the output of the previous run + # we'll use this to diff the input from the prior + delta_documents = await get_delta_docs(input_documents, previous_storage) + return delta_documents.new_inputs diff --git a/graphrag/index/workflows/prune_graph.py b/graphrag/index/workflows/prune_graph.py index c4abad7e..52987a8d 100644 --- a/graphrag/index/workflows/prune_graph.py +++ b/graphrag/index/workflows/prune_graph.py @@ -20,8 +20,10 @@ async def run_workflow( context: PipelineRunContext, ) -> WorkflowFunctionOutput: """All the steps to create the base entity graph.""" - entities = await load_table_from_storage("entities", context.storage) - relationships = await load_table_from_storage("relationships", context.storage) + entities = await load_table_from_storage("entities", context.output_storage) + relationships = await load_table_from_storage( + "relationships", context.output_storage + ) pruned_entities, pruned_relationships = prune_graph( entities, @@ -29,8 +31,10 @@ async def run_workflow( pruning_config=config.prune_graph, ) - await write_table_to_storage(pruned_entities, "entities", context.storage) - await write_table_to_storage(pruned_relationships, "relationships", context.storage) + await write_table_to_storage(pruned_entities, "entities", context.output_storage) + await write_table_to_storage( + pruned_relationships, "relationships", context.output_storage + ) return WorkflowFunctionOutput( result={ diff --git a/graphrag/prompt_tune/loader/input.py b/graphrag/prompt_tune/loader/input.py index fa49ebee..71f802d0 100644 --- a/graphrag/prompt_tune/loader/input.py +++ b/graphrag/prompt_tune/loader/input.py @@ -21,6 +21,7 @@ from graphrag.prompt_tune.defaults import ( 
K, ) from graphrag.prompt_tune.types import DocSelectionType +from graphrag.utils.api import create_storage_from_config def _sample_chunks_from_embeddings( @@ -37,7 +38,6 @@ def _sample_chunks_from_embeddings( async def load_docs_in_chunks( - root: str, config: GraphRagConfig, select_method: DocSelectionType, limit: int, @@ -51,7 +51,8 @@ async def load_docs_in_chunks( embeddings_llm_settings = config.get_language_model_config( config.embed_text.model_id ) - dataset = await create_input(config.input, logger, root) + input_storage = create_storage_from_config(config.input.storage) + dataset = await create_input(config.input, input_storage, logger) chunk_config = config.chunks chunks_df = create_base_text_units( documents=dataset, diff --git a/graphrag/storage/factory.py b/graphrag/storage/factory.py index 8a6e0df4..d9243fb7 100644 --- a/graphrag/storage/factory.py +++ b/graphrag/storage/factory.py @@ -7,7 +7,7 @@ from __future__ import annotations from typing import TYPE_CHECKING, ClassVar -from graphrag.config.enums import OutputType +from graphrag.config.enums import StorageType from graphrag.storage.blob_pipeline_storage import create_blob_storage from graphrag.storage.cosmosdb_pipeline_storage import create_cosmosdb_storage from graphrag.storage.file_pipeline_storage import create_file_storage @@ -35,17 +35,17 @@ class StorageFactory: @classmethod def create_storage( - cls, storage_type: OutputType | str, kwargs: dict + cls, storage_type: StorageType | str, kwargs: dict ) -> PipelineStorage: """Create or get a storage object from the provided type.""" match storage_type: - case OutputType.blob: + case StorageType.blob: return create_blob_storage(**kwargs) - case OutputType.cosmosdb: + case StorageType.cosmosdb: return create_cosmosdb_storage(**kwargs) - case OutputType.file: + case StorageType.file: return create_file_storage(**kwargs) - case OutputType.memory: + case StorageType.memory: return MemoryPipelineStorage() case _: if storage_type in cls.storage_types: diff --git a/graphrag/utils/api.py b/graphrag/utils/api.py index 9b69ef97..2512fb0c 100644 --- a/graphrag/utils/api.py +++ b/graphrag/utils/api.py @@ -10,7 +10,7 @@ from graphrag.cache.factory import CacheFactory from graphrag.cache.pipeline_cache import PipelineCache from graphrag.config.embeddings import create_collection_name from graphrag.config.models.cache_config import CacheConfig -from graphrag.config.models.output_config import OutputConfig +from graphrag.config.models.storage_config import StorageConfig from graphrag.data_model.types import TextEmbedder from graphrag.storage.factory import StorageFactory from graphrag.storage.pipeline_storage import PipelineStorage @@ -238,7 +238,7 @@ def load_search_prompt(root_dir: str, prompt_config: str | None) -> str | None: return None -def create_storage_from_config(output: OutputConfig) -> PipelineStorage: +def create_storage_from_config(output: StorageConfig) -> PipelineStorage: """Create a storage object from the config.""" storage_config = output.model_dump() return StorageFactory().create_storage( diff --git a/tests/fixtures/azure/settings.yml b/tests/fixtures/azure/settings.yml index 04d0e7aa..80ba02e5 100644 --- a/tests/fixtures/azure/settings.yml +++ b/tests/fixtures/azure/settings.yml @@ -9,11 +9,13 @@ vector_store: container_name: "azure_ci" input: - type: blob + storage: + type: blob + connection_string: ${LOCAL_BLOB_STORAGE_CONNECTION_STRING} + container_name: azurefixture + base_dir: input file_type: text - connection_string: ${LOCAL_BLOB_STORAGE_CONNECTION_STRING} 
- container_name: azurefixture - base_dir: input + cache: type: blob @@ -21,7 +23,7 @@ cache: container_name: cicache base_dir: cache_azure_ai -storage: +output: type: blob connection_string: ${LOCAL_BLOB_STORAGE_CONNECTION_STRING} container_name: azurefixture diff --git a/tests/fixtures/min-csv/config.json b/tests/fixtures/min-csv/config.json index de496593..8ef26786 100644 --- a/tests/fixtures/min-csv/config.json +++ b/tests/fixtures/min-csv/config.json @@ -2,9 +2,15 @@ "input_path": "./tests/fixtures/min-csv", "input_file_type": "text", "workflow_config": { + "load_input_documents": { + "max_runtime": 30 + }, "create_base_text_units": { "max_runtime": 30 }, + "extract_covariates": { + "max_runtime": 10 + }, "extract_graph": { "max_runtime": 500 }, diff --git a/tests/fixtures/text/config.json b/tests/fixtures/text/config.json index 50818ad0..cc6e632e 100644 --- a/tests/fixtures/text/config.json +++ b/tests/fixtures/text/config.json @@ -2,6 +2,9 @@ "input_path": "./tests/fixtures/text", "input_file_type": "text", "workflow_config": { + "load_input_documents": { + "max_runtime": 30 + }, "create_base_text_units": { "max_runtime": 30 }, diff --git a/tests/integration/storage/test_factory.py b/tests/integration/storage/test_factory.py index 5be62e90..81e1781d 100644 --- a/tests/integration/storage/test_factory.py +++ b/tests/integration/storage/test_factory.py @@ -9,7 +9,7 @@ import sys import pytest -from graphrag.config.enums import OutputType +from graphrag.config.enums import StorageType from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage from graphrag.storage.factory import StorageFactory @@ -29,7 +29,7 @@ def test_create_blob_storage(): "base_dir": "testbasedir", "container_name": "testcontainer", } - storage = StorageFactory.create_storage(OutputType.blob, kwargs) + storage = StorageFactory.create_storage(StorageType.blob, kwargs) assert isinstance(storage, BlobPipelineStorage) @@ -44,19 +44,19 @@ def test_create_cosmosdb_storage(): "base_dir": "testdatabase", "container_name": "testcontainer", } - storage = StorageFactory.create_storage(OutputType.cosmosdb, kwargs) + storage = StorageFactory.create_storage(StorageType.cosmosdb, kwargs) assert isinstance(storage, CosmosDBPipelineStorage) def test_create_file_storage(): kwargs = {"type": "file", "base_dir": "/tmp/teststorage"} - storage = StorageFactory.create_storage(OutputType.file, kwargs) + storage = StorageFactory.create_storage(StorageType.file, kwargs) assert isinstance(storage, FilePipelineStorage) def test_create_memory_storage(): kwargs = {"type": "memory"} - storage = StorageFactory.create_storage(OutputType.memory, kwargs) + storage = StorageFactory.create_storage(StorageType.memory, kwargs) assert isinstance(storage, MemoryPipelineStorage) diff --git a/tests/unit/config/utils.py b/tests/unit/config/utils.py index 4aea91a0..2fa5b141 100644 --- a/tests/unit/config/utils.py +++ b/tests/unit/config/utils.py @@ -24,10 +24,10 @@ from graphrag.config.models.graph_rag_config import GraphRagConfig from graphrag.config.models.input_config import InputConfig from graphrag.config.models.language_model_config import LanguageModelConfig from graphrag.config.models.local_search_config import LocalSearchConfig -from graphrag.config.models.output_config import OutputConfig from graphrag.config.models.prune_graph_config import PruneGraphConfig from graphrag.config.models.reporting_config import ReportingConfig from 
graphrag.config.models.snapshots_config import SnapshotsConfig +from graphrag.config.models.storage_config import StorageConfig from graphrag.config.models.summarize_descriptions_config import ( SummarizeDescriptionsConfig, ) @@ -134,7 +134,7 @@ def assert_reporting_configs( assert actual.storage_account_blob_url == expected.storage_account_blob_url -def assert_output_configs(actual: OutputConfig, expected: OutputConfig) -> None: +def assert_output_configs(actual: StorageConfig, expected: StorageConfig) -> None: assert expected.type == actual.type assert expected.base_dir == actual.base_dir assert expected.connection_string == actual.connection_string @@ -143,7 +143,9 @@ def assert_output_configs(actual: OutputConfig, expected: OutputConfig) -> None: assert expected.cosmosdb_account_url == actual.cosmosdb_account_url -def assert_update_output_configs(actual: OutputConfig, expected: OutputConfig) -> None: +def assert_update_output_configs( + actual: StorageConfig, expected: StorageConfig +) -> None: assert expected.type == actual.type assert expected.base_dir == actual.base_dir assert expected.connection_string == actual.connection_string @@ -162,12 +164,15 @@ def assert_cache_configs(actual: CacheConfig, expected: CacheConfig) -> None: def assert_input_configs(actual: InputConfig, expected: InputConfig) -> None: - assert actual.type == expected.type + assert actual.storage.type == expected.storage.type assert actual.file_type == expected.file_type - assert actual.base_dir == expected.base_dir - assert actual.connection_string == expected.connection_string - assert actual.storage_account_blob_url == expected.storage_account_blob_url - assert actual.container_name == expected.container_name + assert actual.storage.base_dir == expected.storage.base_dir + assert actual.storage.connection_string == expected.storage.connection_string + assert ( + actual.storage.storage_account_blob_url + == expected.storage.storage_account_blob_url + ) + assert actual.storage.container_name == expected.storage.container_name assert actual.encoding == expected.encoding assert actual.file_pattern == expected.file_pattern assert actual.file_filter == expected.file_filter diff --git a/tests/unit/indexing/input/test_csv_loader.py b/tests/unit/indexing/input/test_csv_loader.py index cc6b0e6f..965f8366 100644 --- a/tests/unit/indexing/input/test_csv_loader.py +++ b/tests/unit/indexing/input/test_csv_loader.py @@ -1,56 +1,66 @@ # Copyright (c) 2024 Microsoft Corporation. 
# Licensed under the MIT License -from graphrag.config.enums import InputFileType, InputType +from graphrag.config.enums import InputFileType from graphrag.config.models.input_config import InputConfig +from graphrag.config.models.storage_config import StorageConfig from graphrag.index.input.factory import create_input +from graphrag.utils.api import create_storage_from_config async def test_csv_loader_one_file(): config = InputConfig( - type=InputType.file, + storage=StorageConfig( + base_dir="tests/unit/indexing/input/data/one-csv", + ), file_type=InputFileType.csv, file_pattern=".*\\.csv$", - base_dir="tests/unit/indexing/input/data/one-csv", ) - documents = await create_input(config=config) + storage = create_storage_from_config(config.storage) + documents = await create_input(config=config, storage=storage) assert documents.shape == (2, 4) assert documents["title"].iloc[0] == "input.csv" async def test_csv_loader_one_file_with_title(): config = InputConfig( - type=InputType.file, + storage=StorageConfig( + base_dir="tests/unit/indexing/input/data/one-csv", + ), file_type=InputFileType.csv, file_pattern=".*\\.csv$", - base_dir="tests/unit/indexing/input/data/one-csv", title_column="title", ) - documents = await create_input(config=config) + storage = create_storage_from_config(config.storage) + documents = await create_input(config=config, storage=storage) assert documents.shape == (2, 4) assert documents["title"].iloc[0] == "Hello" async def test_csv_loader_one_file_with_metadata(): config = InputConfig( - type=InputType.file, + storage=StorageConfig( + base_dir="tests/unit/indexing/input/data/one-csv", + ), file_type=InputFileType.csv, file_pattern=".*\\.csv$", - base_dir="tests/unit/indexing/input/data/one-csv", title_column="title", metadata=["title"], ) - documents = await create_input(config=config) + storage = create_storage_from_config(config.storage) + documents = await create_input(config=config, storage=storage) assert documents.shape == (2, 5) assert documents["metadata"][0] == {"title": "Hello"} async def test_csv_loader_multiple_files(): config = InputConfig( - type=InputType.file, + storage=StorageConfig( + base_dir="tests/unit/indexing/input/data/multiple-csvs", + ), file_type=InputFileType.csv, file_pattern=".*\\.csv$", - base_dir="tests/unit/indexing/input/data/multiple-csvs", ) - documents = await create_input(config=config) + storage = create_storage_from_config(config.storage) + documents = await create_input(config=config, storage=storage) assert documents.shape == (4, 4) diff --git a/tests/unit/indexing/input/test_json_loader.py b/tests/unit/indexing/input/test_json_loader.py index 365ef9b3..c97d38d4 100644 --- a/tests/unit/indexing/input/test_json_loader.py +++ b/tests/unit/indexing/input/test_json_loader.py @@ -1,31 +1,37 @@ # Copyright (c) 2024 Microsoft Corporation. 
# Licensed under the MIT License -from graphrag.config.enums import InputFileType, InputType +from graphrag.config.enums import InputFileType from graphrag.config.models.input_config import InputConfig +from graphrag.config.models.storage_config import StorageConfig from graphrag.index.input.factory import create_input +from graphrag.utils.api import create_storage_from_config async def test_json_loader_one_file_one_object(): config = InputConfig( - type=InputType.file, + storage=StorageConfig( + base_dir="tests/unit/indexing/input/data/one-json-one-object", + ), file_type=InputFileType.json, file_pattern=".*\\.json$", - base_dir="tests/unit/indexing/input/data/one-json-one-object", ) - documents = await create_input(config=config) + storage = create_storage_from_config(config.storage) + documents = await create_input(config=config, storage=storage) assert documents.shape == (1, 4) assert documents["title"].iloc[0] == "input.json" async def test_json_loader_one_file_multiple_objects(): config = InputConfig( - type=InputType.file, + storage=StorageConfig( + base_dir="tests/unit/indexing/input/data/one-json-multiple-objects", + ), file_type=InputFileType.json, file_pattern=".*\\.json$", - base_dir="tests/unit/indexing/input/data/one-json-multiple-objects", ) - documents = await create_input(config=config) + storage = create_storage_from_config(config.storage) + documents = await create_input(config=config, storage=storage) print(documents) assert documents.shape == (3, 4) assert documents["title"].iloc[0] == "input.json" @@ -33,37 +39,43 @@ async def test_json_loader_one_file_multiple_objects(): async def test_json_loader_one_file_with_title(): config = InputConfig( - type=InputType.file, + storage=StorageConfig( + base_dir="tests/unit/indexing/input/data/one-json-one-object", + ), file_type=InputFileType.json, file_pattern=".*\\.json$", - base_dir="tests/unit/indexing/input/data/one-json-one-object", title_column="title", ) - documents = await create_input(config=config) + storage = create_storage_from_config(config.storage) + documents = await create_input(config=config, storage=storage) assert documents.shape == (1, 4) assert documents["title"].iloc[0] == "Hello" async def test_json_loader_one_file_with_metadata(): config = InputConfig( - type=InputType.file, + storage=StorageConfig( + base_dir="tests/unit/indexing/input/data/one-json-one-object", + ), file_type=InputFileType.json, file_pattern=".*\\.json$", - base_dir="tests/unit/indexing/input/data/one-json-one-object", title_column="title", metadata=["title"], ) - documents = await create_input(config=config) + storage = create_storage_from_config(config.storage) + documents = await create_input(config=config, storage=storage) assert documents.shape == (1, 5) assert documents["metadata"][0] == {"title": "Hello"} async def test_json_loader_multiple_files(): config = InputConfig( - type=InputType.file, + storage=StorageConfig( + base_dir="tests/unit/indexing/input/data/multiple-jsons", + ), file_type=InputFileType.json, file_pattern=".*\\.json$", - base_dir="tests/unit/indexing/input/data/multiple-jsons", ) - documents = await create_input(config=config) + storage = create_storage_from_config(config.storage) + documents = await create_input(config=config, storage=storage) assert documents.shape == (4, 4) diff --git a/tests/unit/indexing/input/test_txt_loader.py b/tests/unit/indexing/input/test_txt_loader.py index 05987f4b..6b82a408 100644 --- a/tests/unit/indexing/input/test_txt_loader.py +++ 
b/tests/unit/indexing/input/test_txt_loader.py @@ -1,32 +1,38 @@ # Copyright (c) 2024 Microsoft Corporation. # Licensed under the MIT License -from graphrag.config.enums import InputFileType, InputType +from graphrag.config.enums import InputFileType from graphrag.config.models.input_config import InputConfig +from graphrag.config.models.storage_config import StorageConfig from graphrag.index.input.factory import create_input +from graphrag.utils.api import create_storage_from_config async def test_txt_loader_one_file(): config = InputConfig( - type=InputType.file, + storage=StorageConfig( + base_dir="tests/unit/indexing/input/data/one-txt", + ), file_type=InputFileType.text, file_pattern=".*\\.txt$", - base_dir="tests/unit/indexing/input/data/one-txt", ) - documents = await create_input(config=config) + storage = create_storage_from_config(config.storage) + documents = await create_input(config=config, storage=storage) assert documents.shape == (1, 4) assert documents["title"].iloc[0] == "input.txt" async def test_txt_loader_one_file_with_metadata(): config = InputConfig( - type=InputType.file, + storage=StorageConfig( + base_dir="tests/unit/indexing/input/data/one-txt", + ), file_type=InputFileType.text, file_pattern=".*\\.txt$", - base_dir="tests/unit/indexing/input/data/one-txt", metadata=["title"], ) - documents = await create_input(config=config) + storage = create_storage_from_config(config.storage) + documents = await create_input(config=config, storage=storage) assert documents.shape == (1, 5) # unlike csv, we cannot set the title to anything other than the filename assert documents["metadata"][0] == {"title": "input.txt"} @@ -34,10 +40,12 @@ async def test_txt_loader_one_file_with_metadata(): async def test_txt_loader_multiple_files(): config = InputConfig( - type=InputType.file, + storage=StorageConfig( + base_dir="tests/unit/indexing/input/data/multiple-txts", + ), file_type=InputFileType.text, file_pattern=".*\\.txt$", - base_dir="tests/unit/indexing/input/data/multiple-txts", ) - documents = await create_input(config=config) + storage = create_storage_from_config(config.storage) + documents = await create_input(config=config, storage=storage) assert documents.shape == (2, 4) diff --git a/tests/verbs/test_create_base_text_units.py b/tests/verbs/test_create_base_text_units.py index 73f55c27..ea34ae8b 100644 --- a/tests/verbs/test_create_base_text_units.py +++ b/tests/verbs/test_create_base_text_units.py @@ -23,7 +23,7 @@ async def test_create_base_text_units(): await run_workflow(config, context) - actual = await load_table_from_storage("text_units", context.storage) + actual = await load_table_from_storage("text_units", context.output_storage) compare_outputs(actual, expected, columns=["text", "document_ids", "n_tokens"]) @@ -43,7 +43,7 @@ async def test_create_base_text_units_metadata(): await run_workflow(config, context) - actual = await load_table_from_storage("text_units", context.storage) + actual = await load_table_from_storage("text_units", context.output_storage) compare_outputs(actual, expected) @@ -63,6 +63,6 @@ async def test_create_base_text_units_metadata_included_in_chunk(): await run_workflow(config, context) - actual = await load_table_from_storage("text_units", context.storage) + actual = await load_table_from_storage("text_units", context.output_storage) # only check the columns from the base workflow - our expected table is the final and will have more compare_outputs(actual, expected, columns=["text", "document_ids", "n_tokens"]) diff --git 
a/tests/verbs/test_create_communities.py b/tests/verbs/test_create_communities.py index 15d646a6..1f51667c 100644 --- a/tests/verbs/test_create_communities.py +++ b/tests/verbs/test_create_communities.py @@ -33,7 +33,7 @@ async def test_create_communities(): context, ) - actual = await load_table_from_storage("communities", context.storage) + actual = await load_table_from_storage("communities", context.output_storage) columns = list(expected.columns.values) # don't compare period since it is created with the current date each time diff --git a/tests/verbs/test_create_community_reports.py b/tests/verbs/test_create_community_reports.py index be8c1dbe..56fe4a62 100644 --- a/tests/verbs/test_create_community_reports.py +++ b/tests/verbs/test_create_community_reports.py @@ -66,7 +66,7 @@ async def test_create_community_reports(): await run_workflow(config, context) - actual = await load_table_from_storage("community_reports", context.storage) + actual = await load_table_from_storage("community_reports", context.output_storage) assert len(actual.columns) == len(expected.columns) diff --git a/tests/verbs/test_create_final_documents.py b/tests/verbs/test_create_final_documents.py index a5cfe025..89ff2d7c 100644 --- a/tests/verbs/test_create_final_documents.py +++ b/tests/verbs/test_create_final_documents.py @@ -28,7 +28,7 @@ async def test_create_final_documents(): await run_workflow(config, context) - actual = await load_table_from_storage("documents", context.storage) + actual = await load_table_from_storage("documents", context.output_storage) compare_outputs(actual, expected) @@ -47,11 +47,11 @@ async def test_create_final_documents_with_metadata_column(): # simulate the metadata construction during initial input loading await update_document_metadata(config.input.metadata, context) - expected = await load_table_from_storage("documents", context.storage) + expected = await load_table_from_storage("documents", context.output_storage) await run_workflow(config, context) - actual = await load_table_from_storage("documents", context.storage) + actual = await load_table_from_storage("documents", context.output_storage) compare_outputs(actual, expected) diff --git a/tests/verbs/test_create_final_text_units.py b/tests/verbs/test_create_final_text_units.py index 7b20a13a..979f48d5 100644 --- a/tests/verbs/test_create_final_text_units.py +++ b/tests/verbs/test_create_final_text_units.py @@ -33,7 +33,7 @@ async def test_create_final_text_units(): await run_workflow(config, context) - actual = await load_table_from_storage("text_units", context.storage) + actual = await load_table_from_storage("text_units", context.output_storage) for column in TEXT_UNITS_FINAL_COLUMNS: assert column in actual.columns diff --git a/tests/verbs/test_extract_covariates.py b/tests/verbs/test_extract_covariates.py index bad87d41..7e38f25a 100644 --- a/tests/verbs/test_extract_covariates.py +++ b/tests/verbs/test_extract_covariates.py @@ -37,6 +37,7 @@ async def test_extract_covariates(): ).model_dump() llm_settings["type"] = ModelType.MockChat llm_settings["responses"] = MOCK_LLM_RESPONSES + config.extract_claims.enabled = True config.extract_claims.strategy = { "type": "graph_intelligence", "llm": llm_settings, @@ -45,7 +46,7 @@ async def test_extract_covariates(): await run_workflow(config, context) - actual = await load_table_from_storage("covariates", context.storage) + actual = await load_table_from_storage("covariates", context.output_storage) for column in COVARIATES_FINAL_COLUMNS: assert column in actual.columns diff 
--git a/tests/verbs/test_extract_graph.py b/tests/verbs/test_extract_graph.py index 618b8430..145d161d 100644 --- a/tests/verbs/test_extract_graph.py +++ b/tests/verbs/test_extract_graph.py @@ -63,8 +63,10 @@ async def test_extract_graph(): await run_workflow(config, context) - nodes_actual = await load_table_from_storage("entities", context.storage) - edges_actual = await load_table_from_storage("relationships", context.storage) + nodes_actual = await load_table_from_storage("entities", context.output_storage) + edges_actual = await load_table_from_storage( + "relationships", context.output_storage + ) assert len(nodes_actual.columns) == 5 assert len(edges_actual.columns) == 5 diff --git a/tests/verbs/test_extract_graph_nlp.py b/tests/verbs/test_extract_graph_nlp.py index 8ec14df2..8c9367ad 100644 --- a/tests/verbs/test_extract_graph_nlp.py +++ b/tests/verbs/test_extract_graph_nlp.py @@ -22,8 +22,10 @@ async def test_extract_graph_nlp(): await run_workflow(config, context) - nodes_actual = await load_table_from_storage("entities", context.storage) - edges_actual = await load_table_from_storage("relationships", context.storage) + nodes_actual = await load_table_from_storage("entities", context.output_storage) + edges_actual = await load_table_from_storage( + "relationships", context.output_storage + ) # this will be the raw count of entities and edges with no pruning # with NLP it is deterministic, so we can assert exact row counts diff --git a/tests/verbs/test_finalize_graph.py b/tests/verbs/test_finalize_graph.py index 7ff25ca3..b70ba8dd 100644 --- a/tests/verbs/test_finalize_graph.py +++ b/tests/verbs/test_finalize_graph.py @@ -25,8 +25,10 @@ async def test_finalize_graph(): await run_workflow(config, context) - nodes_actual = await load_table_from_storage("entities", context.storage) - edges_actual = await load_table_from_storage("relationships", context.storage) + nodes_actual = await load_table_from_storage("entities", context.output_storage) + edges_actual = await load_table_from_storage( + "relationships", context.output_storage + ) assert len(nodes_actual) == 291 assert len(edges_actual) == 452 @@ -51,8 +53,10 @@ async def test_finalize_graph_umap(): await run_workflow(config, context) - nodes_actual = await load_table_from_storage("entities", context.storage) - edges_actual = await load_table_from_storage("relationships", context.storage) + nodes_actual = await load_table_from_storage("entities", context.output_storage) + edges_actual = await load_table_from_storage( + "relationships", context.output_storage + ) assert len(nodes_actual) == 291 assert len(edges_actual) == 452 @@ -75,8 +79,8 @@ async def _prep_tables(): # edit the tables to eliminate final fields that wouldn't be on the inputs entities = load_test_table("entities") entities.drop(columns=["x", "y", "degree"], inplace=True) - await write_table_to_storage(entities, "entities", context.storage) + await write_table_to_storage(entities, "entities", context.output_storage) relationships = load_test_table("relationships") relationships.drop(columns=["combined_degree"], inplace=True) - await write_table_to_storage(relationships, "relationships", context.storage) + await write_table_to_storage(relationships, "relationships", context.output_storage) return context diff --git a/tests/verbs/test_generate_text_embeddings.py b/tests/verbs/test_generate_text_embeddings.py index 788ddfc9..b0e47d16 100644 --- a/tests/verbs/test_generate_text_embeddings.py +++ b/tests/verbs/test_generate_text_embeddings.py @@ -44,14 +44,14 @@ async 
def test_generate_text_embeddings(): await run_workflow(config, context) - parquet_files = context.storage.keys() + parquet_files = context.output_storage.keys() for field in all_embeddings: assert f"embeddings.{field}.parquet" in parquet_files # entity description should always be here, let's assert its format entity_description_embeddings = await load_table_from_storage( - "embeddings.entity.description", context.storage + "embeddings.entity.description", context.output_storage ) assert len(entity_description_embeddings.columns) == 2 @@ -60,7 +60,7 @@ async def test_generate_text_embeddings(): # every other embedding is optional but we've turned them all on, so check a random one document_text_embeddings = await load_table_from_storage( - "embeddings.document.text", context.storage + "embeddings.document.text", context.output_storage ) assert len(document_text_embeddings.columns) == 2 diff --git a/tests/verbs/test_prune_graph.py b/tests/verbs/test_prune_graph.py index b30546bc..6ed00019 100644 --- a/tests/verbs/test_prune_graph.py +++ b/tests/verbs/test_prune_graph.py @@ -26,6 +26,6 @@ async def test_prune_graph(): await run_workflow(config, context) - nodes_actual = await load_table_from_storage("entities", context.storage) + nodes_actual = await load_table_from_storage("entities", context.output_storage) assert len(nodes_actual) == 21 diff --git a/tests/verbs/util.py b/tests/verbs/util.py index acbd44d8..8d342b47 100644 --- a/tests/verbs/util.py +++ b/tests/verbs/util.py @@ -38,12 +38,12 @@ async def create_test_context(storage: list[str] | None = None) -> PipelineRunCo # always set the input docs, but since our stored table is final, drop what wouldn't be in the original source input input = load_test_table("documents") input.drop(columns=["text_unit_ids"], inplace=True) - await write_table_to_storage(input, "documents", context.storage) + await write_table_to_storage(input, "documents", context.output_storage) if storage: for name in storage: table = load_test_table(name) - await write_table_to_storage(table, name, context.storage) + await write_table_to_storage(table, name, context.output_storage) return context @@ -86,8 +86,8 @@ def compare_outputs( async def update_document_metadata(metadata: list[str], context: PipelineRunContext): """Takes the default documents and adds the configured metadata columns for later parsing by the text units and final documents workflows.""" - documents = await load_table_from_storage("documents", context.storage) + documents = await load_table_from_storage("documents", context.output_storage) documents["metadata"] = documents[metadata].apply(lambda row: row.to_dict(), axis=1) await write_table_to_storage( - documents, "documents", context.storage + documents, "documents", context.output_storage ) # write to the runtime context storage only
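
Taken together, these changes replace root-relative input loading with an explicit storage object: `InputConfig` now nests a `StorageConfig`, callers build a `PipelineStorage` via `create_storage_from_config` (which dispatches through `StorageFactory.create_storage` keyed on `StorageType`), and `create_input` receives that storage instead of a root path. The sketch below is a usage illustration only, not part of the patch; it mirrors the updated loader tests, and the `data/input` directory is a placeholder.

```python
import asyncio

from graphrag.config.enums import InputFileType
from graphrag.config.models.input_config import InputConfig
from graphrag.config.models.storage_config import StorageConfig
from graphrag.index.input.factory import create_input
from graphrag.utils.api import create_storage_from_config


async def load_documents():
    # InputConfig nests a StorageConfig; "data/input" is a placeholder directory.
    config = InputConfig(
        storage=StorageConfig(base_dir="data/input"),
        file_type=InputFileType.csv,
        file_pattern=".*\\.csv$",
    )
    # create_storage_from_config dumps the StorageConfig and routes it through
    # StorageFactory.create_storage; with the default type this yields file storage.
    storage = create_storage_from_config(config.storage)
    # create_input now takes the storage object rather than a root path.
    return await create_input(config=config, storage=storage)


if __name__ == "__main__":
    documents = asyncio.run(load_documents())
    print(documents.shape)
```

The same factory path can also be called directly, as the updated integration tests do with `StorageFactory.create_storage(StorageType.blob, ...)`, `StorageType.cosmosdb`, `StorageType.file`, and `StorageType.memory`.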