Pipeline registration (#1940)

* Move covariate run conditional

* Allow pipeline registration

* Fix method name construction

* Rename context storage -> output_storage

* Rename OutputConfig as generic StorageConfig

* Reuse Storage model under InputConfig

* Move input storage creation out of document loading

* Move document loading into workflows

* Semver

* Fix smoke test config for new workflows

* Fix unit tests

---------

Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
Nathan Evans 2025-06-12 16:14:39 -07:00 committed by GitHub
parent 17e431cf42
commit 1df89727c3
60 changed files with 602 additions and 424 deletions

View File

@ -0,0 +1,4 @@
{
"type": "minor",
"description": "Allow injection of custom pipelines."
}

View File

@ -27,7 +27,7 @@ log = logging.getLogger(__name__)
async def build_index(
config: GraphRagConfig,
method: IndexingMethod = IndexingMethod.Standard,
method: IndexingMethod | str = IndexingMethod.Standard,
is_update_run: bool = False,
memory_profile: bool = False,
callbacks: list[WorkflowCallbacks] | None = None,
@ -65,7 +65,9 @@ async def build_index(
if memory_profile:
log.warning("New pipeline does not yet support memory profiling.")
pipeline = PipelineFactory.create_pipeline(config, method, is_update_run)
# todo: this could propagate out to the cli for better clarity, but will be a breaking api change
method = _get_method(method, is_update_run)
pipeline = PipelineFactory.create_pipeline(config, method)
workflow_callbacks.pipeline_start(pipeline.names())
@ -90,3 +92,8 @@ async def build_index(
def register_workflow_function(name: str, workflow: WorkflowFunction):
"""Register a custom workflow function. You can then include the name in the settings.yaml workflows list."""
PipelineFactory.register(name, workflow)
def _get_method(method: IndexingMethod | str, is_update_run: bool) -> str:
m = method.value if isinstance(method, IndexingMethod) else method
return f"{m}-update" if is_update_run else m
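
For reference, a minimal sketch of how an indexing method now resolves to a registered pipeline name, mirroring the private _get_method helper above (the helper itself stays internal):

from graphrag.config.enums import IndexingMethod

def resolve_pipeline_name(method: IndexingMethod | str, is_update_run: bool) -> str:
    # Accept the enum or its raw string value, matching the relaxed build_index signature.
    name = method.value if isinstance(method, IndexingMethod) else method
    # Update runs are routed to the "-update" pipelines registered with PipelineFactory.
    return f"{name}-update" if is_update_run else name

assert resolve_pipeline_name(IndexingMethod.Fast, is_update_run=True) == "fast-update"
assert resolve_pipeline_name("standard", is_update_run=False) == "standard"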

View File

@ -52,7 +52,6 @@ from graphrag.prompt_tune.types import DocSelectionType
async def generate_indexing_prompts(
config: GraphRagConfig,
logger: ProgressLogger,
root: str,
chunk_size: PositiveInt = graphrag_config_defaults.chunks.size,
overlap: Annotated[
int, annotated_types.Gt(-1)
@ -93,7 +92,6 @@ async def generate_indexing_prompts(
# Retrieve documents
logger.info("Chunking documents...")
doc_list = await load_docs_in_chunks(
root=root,
config=config,
limit=limit,
select_method=selection_method,

View File

@ -80,7 +80,6 @@ def index_cli(
cli_overrides["reporting.base_dir"] = str(output_dir)
cli_overrides["update_index_output.base_dir"] = str(output_dir)
config = load_config(root_dir, config_filepath, cli_overrides)
_run_index(
config=config,
method=method,

View File

@ -86,7 +86,6 @@ async def prompt_tune(
prompts = await api.generate_indexing_prompts(
config=graph_config,
root=str(root_path),
logger=progress_logger,
chunk_size=chunk_size,
overlap=overlap,

View File

@ -14,11 +14,10 @@ from graphrag.config.enums import (
CacheType,
ChunkStrategyType,
InputFileType,
InputType,
ModelType,
NounPhraseExtractorType,
OutputType,
ReportingType,
StorageType,
)
from graphrag.index.operations.build_noun_graph.np_extractors.stop_words import (
EN_STOP_WORDS,
@ -234,16 +233,31 @@ class GlobalSearchDefaults:
chat_model_id: str = DEFAULT_CHAT_MODEL_ID
@dataclass
class StorageDefaults:
"""Default values for storage."""
type = StorageType.file
base_dir: str = DEFAULT_OUTPUT_BASE_DIR
connection_string: None = None
container_name: None = None
storage_account_blob_url: None = None
cosmosdb_account_url: None = None
@dataclass
class InputStorageDefaults(StorageDefaults):
"""Default values for input storage."""
base_dir: str = "input"
@dataclass
class InputDefaults:
"""Default values for input."""
type = InputType.file
storage: InputStorageDefaults = field(default_factory=InputStorageDefaults)
file_type = InputFileType.text
base_dir: str = "input"
connection_string: None = None
storage_account_blob_url: None = None
container_name: None = None
encoding: str = "utf-8"
file_pattern: str = ""
file_filter: None = None
@ -301,15 +315,10 @@ class LocalSearchDefaults:
@dataclass
class OutputDefaults:
class OutputDefaults(StorageDefaults):
"""Default values for output."""
type = OutputType.file
base_dir: str = DEFAULT_OUTPUT_BASE_DIR
connection_string: None = None
container_name: None = None
storage_account_blob_url: None = None
cosmosdb_account_url: None = None
@dataclass
@ -364,14 +373,10 @@ class UmapDefaults:
@dataclass
class UpdateIndexOutputDefaults:
class UpdateIndexOutputDefaults(StorageDefaults):
"""Default values for update index output."""
type = OutputType.file
base_dir: str = "update_output"
connection_string: None = None
container_name: None = None
storage_account_blob_url: None = None
@dataclass
@ -395,6 +400,7 @@ class GraphRagConfigDefaults:
root_dir: str = ""
models: dict = field(default_factory=dict)
reporting: ReportingDefaults = field(default_factory=ReportingDefaults)
storage: StorageDefaults = field(default_factory=StorageDefaults)
output: OutputDefaults = field(default_factory=OutputDefaults)
outputs: None = None
update_index_output: UpdateIndexOutputDefaults = field(

View File

@ -42,20 +42,7 @@ class InputFileType(str, Enum):
return f'"{self.value}"'
class InputType(str, Enum):
"""The input type for the pipeline."""
file = "file"
"""The file storage type."""
blob = "blob"
"""The blob storage type."""
def __repr__(self):
"""Get a string representation."""
return f'"{self.value}"'
class OutputType(str, Enum):
class StorageType(str, Enum):
"""The output type for the pipeline."""
file = "file"
@ -152,6 +139,10 @@ class IndexingMethod(str, Enum):
"""Traditional GraphRAG indexing, with all graph construction and summarization performed by a language model."""
Fast = "fast"
"""Fast indexing, using NLP for graph construction and language model for summarization."""
StandardUpdate = "standard-update"
"""Incremental update with standard indexing."""
FastUpdate = "fast-update"
"""Incremental update with fast indexing."""
class NounPhraseExtractorType(str, Enum):

View File

@ -58,9 +58,11 @@ models:
### Input settings ###
input:
type: {graphrag_config_defaults.input.type.value} # or blob
storage:
type: {graphrag_config_defaults.input.storage.type.value} # or blob
base_dir: "{graphrag_config_defaults.input.storage.base_dir}"
file_type: {graphrag_config_defaults.input.file_type.value} # [csv, text, json]
base_dir: "{graphrag_config_defaults.input.base_dir}"
chunks:
size: {graphrag_config_defaults.chunks.size}
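
Rendered with the defaults above, the generated settings now nest the input storage block; a representative excerpt (values shown are the rendered defaults, not required literals):

input:
  storage:
    type: file # or blob
    base_dir: "input"
  file_type: text # [csv, text, json]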

View File

@ -26,10 +26,10 @@ from graphrag.config.models.global_search_config import GlobalSearchConfig
from graphrag.config.models.input_config import InputConfig
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.config.models.local_search_config import LocalSearchConfig
from graphrag.config.models.output_config import OutputConfig
from graphrag.config.models.prune_graph_config import PruneGraphConfig
from graphrag.config.models.reporting_config import ReportingConfig
from graphrag.config.models.snapshots_config import SnapshotsConfig
from graphrag.config.models.storage_config import StorageConfig
from graphrag.config.models.summarize_descriptions_config import (
SummarizeDescriptionsConfig,
)
@ -102,21 +102,31 @@ class GraphRagConfig(BaseModel):
else:
self.input.file_pattern = f".*\\.{self.input.file_type.value}$"
def _validate_input_base_dir(self) -> None:
"""Validate the input base directory."""
if self.input.storage.type == defs.StorageType.file:
if self.input.storage.base_dir.strip() == "":
msg = "input storage base directory is required for file input storage. Please rerun `graphrag init` and set the input storage configuration."
raise ValueError(msg)
self.input.storage.base_dir = str(
(Path(self.root_dir) / self.input.storage.base_dir).resolve()
)
chunks: ChunkingConfig = Field(
description="The chunking configuration to use.",
default=ChunkingConfig(),
)
"""The chunking configuration to use."""
output: OutputConfig = Field(
output: StorageConfig = Field(
description="The output configuration.",
default=OutputConfig(),
default=StorageConfig(),
)
"""The output configuration."""
def _validate_output_base_dir(self) -> None:
"""Validate the output base directory."""
if self.output.type == defs.OutputType.file:
if self.output.type == defs.StorageType.file:
if self.output.base_dir.strip() == "":
msg = "output base directory is required for file output. Please rerun `graphrag init` and set the output configuration."
raise ValueError(msg)
@ -124,7 +134,7 @@ class GraphRagConfig(BaseModel):
(Path(self.root_dir) / self.output.base_dir).resolve()
)
outputs: dict[str, OutputConfig] | None = Field(
outputs: dict[str, StorageConfig] | None = Field(
description="A list of output configurations used for multi-index query.",
default=graphrag_config_defaults.outputs,
)
@ -133,7 +143,7 @@ class GraphRagConfig(BaseModel):
"""Validate the outputs dict base directories."""
if self.outputs:
for output in self.outputs.values():
if output.type == defs.OutputType.file:
if output.type == defs.StorageType.file:
if output.base_dir.strip() == "":
msg = "Output base directory is required for file output. Please rerun `graphrag init` and set the output configuration."
raise ValueError(msg)
@ -141,10 +151,9 @@ class GraphRagConfig(BaseModel):
(Path(self.root_dir) / output.base_dir).resolve()
)
update_index_output: OutputConfig = Field(
update_index_output: StorageConfig = Field(
description="The output configuration for the updated index.",
default=OutputConfig(
type=graphrag_config_defaults.update_index_output.type,
default=StorageConfig(
base_dir=graphrag_config_defaults.update_index_output.base_dir,
),
)
@ -152,7 +161,7 @@ class GraphRagConfig(BaseModel):
def _validate_update_index_output_base_dir(self) -> None:
"""Validate the update index output base directory."""
if self.update_index_output.type == defs.OutputType.file:
if self.update_index_output.type == defs.StorageType.file:
if self.update_index_output.base_dir.strip() == "":
msg = "update_index_output base directory is required for file output. Please rerun `graphrag init` and set the update_index_output configuration."
raise ValueError(msg)
@ -345,6 +354,7 @@ class GraphRagConfig(BaseModel):
self._validate_root_dir()
self._validate_models()
self._validate_input_pattern()
self._validate_input_base_dir()
self._validate_reporting_base_dir()
self._validate_output_base_dir()
self._validate_multi_output_base_dirs()
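
A minimal sketch of the reshaped config surface, using only the models touched in this diff (no language model settings shown):

from graphrag.config.models.input_config import InputConfig
from graphrag.config.models.storage_config import StorageConfig

# Input, output, and update_index_output now share the same StorageConfig shape.
input_cfg = InputConfig(storage=StorageConfig(base_dir="input"))
output_cfg = StorageConfig(base_dir="output")
update_cfg = StorageConfig(base_dir="update_output")

# Defaults come from StorageDefaults, so all three are file storage unless overridden.
assert input_cfg.storage.type == output_cfg.type == update_cfg.type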

View File

@ -7,36 +7,23 @@ from pydantic import BaseModel, Field
import graphrag.config.defaults as defs
from graphrag.config.defaults import graphrag_config_defaults
from graphrag.config.enums import InputFileType, InputType
from graphrag.config.enums import InputFileType
from graphrag.config.models.storage_config import StorageConfig
class InputConfig(BaseModel):
"""The default configuration section for Input."""
type: InputType = Field(
description="The input type to use.",
default=graphrag_config_defaults.input.type,
storage: StorageConfig = Field(
description="The storage configuration to use for reading input documents.",
default=StorageConfig(
base_dir=graphrag_config_defaults.input.storage.base_dir,
),
)
file_type: InputFileType = Field(
description="The input file type to use.",
default=graphrag_config_defaults.input.file_type,
)
base_dir: str = Field(
description="The input base directory to use.",
default=graphrag_config_defaults.input.base_dir,
)
connection_string: str | None = Field(
description="The azure blob storage connection string to use.",
default=graphrag_config_defaults.input.connection_string,
)
storage_account_blob_url: str | None = Field(
description="The storage account blob url to use.",
default=graphrag_config_defaults.input.storage_account_blob_url,
)
container_name: str | None = Field(
description="The azure blob storage container name to use.",
default=graphrag_config_defaults.input.container_name,
)
encoding: str = Field(
description="The input file encoding to use.",
default=defs.graphrag_config_defaults.input.encoding,

View File

@ -1,38 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Parameterization settings for the default configuration."""
from pydantic import BaseModel, Field
from graphrag.config.defaults import graphrag_config_defaults
from graphrag.config.enums import OutputType
class OutputConfig(BaseModel):
"""The default configuration section for Output."""
type: OutputType = Field(
description="The output type to use.",
default=graphrag_config_defaults.output.type,
)
base_dir: str = Field(
description="The base directory for the output.",
default=graphrag_config_defaults.output.base_dir,
)
connection_string: str | None = Field(
description="The storage connection string to use.",
default=graphrag_config_defaults.output.connection_string,
)
container_name: str | None = Field(
description="The storage container name to use.",
default=graphrag_config_defaults.output.container_name,
)
storage_account_blob_url: str | None = Field(
description="The storage account blob url to use.",
default=graphrag_config_defaults.output.storage_account_blob_url,
)
cosmosdb_account_url: str | None = Field(
description="The cosmosdb account url to use.",
default=graphrag_config_defaults.output.cosmosdb_account_url,
)

View File

@ -0,0 +1,52 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Parameterization settings for the default configuration."""
from pathlib import Path
from pydantic import BaseModel, Field, field_validator
from graphrag.config.defaults import graphrag_config_defaults
from graphrag.config.enums import StorageType
class StorageConfig(BaseModel):
"""The default configuration section for storage."""
type: StorageType = Field(
description="The storage type to use.",
default=graphrag_config_defaults.storage.type,
)
base_dir: str = Field(
description="The base directory for the output.",
default=graphrag_config_defaults.storage.base_dir,
)
# Validate the base dir for multiple OS (use Path)
# if not using a cloud storage type.
@field_validator("base_dir", mode="before")
@classmethod
def validate_base_dir(cls, value, info):
"""Ensure that base_dir is a valid filesystem path when using local storage."""
# info.data contains other field values, including 'type'
if info.data.get("type") != StorageType.file:
return value
return str(Path(value))
connection_string: str | None = Field(
description="The storage connection string to use.",
default=graphrag_config_defaults.storage.connection_string,
)
container_name: str | None = Field(
description="The storage container name to use.",
default=graphrag_config_defaults.storage.container_name,
)
storage_account_blob_url: str | None = Field(
description="The storage account blob url to use.",
default=graphrag_config_defaults.storage.storage_account_blob_url,
)
cosmosdb_account_url: str | None = Field(
description="The cosmosdb account url to use.",
default=graphrag_config_defaults.storage.cosmosdb_account_url,
)
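
A short usage sketch of the new model; the base_dir validator only applies Path normalization for file storage, and the resulting config can be turned into a PipelineStorage via create_storage_from_config (shown in the utils diff further down):

from graphrag.config.enums import StorageType
from graphrag.config.models.storage_config import StorageConfig
from graphrag.utils.api import create_storage_from_config

# File storage: base_dir is run through Path by the validator above.
local = StorageConfig(type=StorageType.file, base_dir="output/artifacts")

# Blob storage: base_dir is left as-is and the blob fields apply instead
# (container/account values here are hypothetical placeholders).
blob = StorageConfig(
    type=StorageType.blob,
    base_dir="input",
    container_name="my-container",
    storage_account_blob_url="https://myaccount.blob.core.windows.net",
)

storage = create_storage_from_config(local)  # returns a file-based PipelineStorage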

View File

@ -22,7 +22,7 @@ async def load_csv(
storage: PipelineStorage,
) -> pd.DataFrame:
"""Load csv inputs from a directory."""
log.info("Loading csv files from %s", config.base_dir)
log.info("Loading csv files from %s", config.storage.base_dir)
async def load_file(path: str, group: dict | None) -> pd.DataFrame:
if group is None:

View File

@ -5,20 +5,18 @@
import logging
from collections.abc import Awaitable, Callable
from pathlib import Path
from typing import cast
import pandas as pd
from graphrag.config.enums import InputFileType, InputType
from graphrag.config.enums import InputFileType
from graphrag.config.models.input_config import InputConfig
from graphrag.index.input.csv import load_csv
from graphrag.index.input.json import load_json
from graphrag.index.input.text import load_text
from graphrag.logger.base import ProgressLogger
from graphrag.logger.null_progress import NullProgressLogger
from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
from graphrag.storage.file_pipeline_storage import FilePipelineStorage
from graphrag.storage.pipeline_storage import PipelineStorage
log = logging.getLogger(__name__)
loaders: dict[str, Callable[..., Awaitable[pd.DataFrame]]] = {
@ -30,43 +28,12 @@ loaders: dict[str, Callable[..., Awaitable[pd.DataFrame]]] = {
async def create_input(
config: InputConfig,
storage: PipelineStorage,
progress_reporter: ProgressLogger | None = None,
root_dir: str | None = None,
) -> pd.DataFrame:
"""Instantiate input data for a pipeline."""
root_dir = root_dir or ""
log.info("loading input from root_dir=%s", config.base_dir)
progress_reporter = progress_reporter or NullProgressLogger()
match config.type:
case InputType.blob:
log.info("using blob storage input")
if config.container_name is None:
msg = "Container name required for blob storage"
raise ValueError(msg)
if (
config.connection_string is None
and config.storage_account_blob_url is None
):
msg = "Connection string or storage account blob url required for blob storage"
raise ValueError(msg)
storage = BlobPipelineStorage(
connection_string=config.connection_string,
storage_account_blob_url=config.storage_account_blob_url,
container_name=config.container_name,
path_prefix=config.base_dir,
)
case InputType.file:
log.info("using file storage for input")
storage = FilePipelineStorage(
root_dir=str(Path(root_dir) / (config.base_dir or ""))
)
case _:
log.info("using file storage for input")
storage = FilePipelineStorage(
root_dir=str(Path(root_dir) / (config.base_dir or ""))
)
if config.file_type in loaders:
progress = progress_reporter.child(
f"Loading Input ({config.file_type})", transient=False

View File

@ -22,7 +22,7 @@ async def load_json(
storage: PipelineStorage,
) -> pd.DataFrame:
"""Load json inputs from a directory."""
log.info("Loading json files from %s", config.base_dir)
log.info("Loading json files from %s", config.storage.base_dir)
async def load_file(path: str, group: dict | None) -> pd.DataFrame:
if group is None:

View File

@ -33,7 +33,7 @@ async def load_files(
)
if len(files) == 0:
msg = f"No {config.file_type} files found in {config.base_dir}"
msg = f"No {config.file_type} files found in {config.storage.base_dir}"
raise ValueError(msg)
files_loaded = []

View File

@ -11,16 +11,12 @@ import traceback
from collections.abc import AsyncIterable
from dataclasses import asdict
import pandas as pd
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.input.factory import create_input
from graphrag.index.run.utils import create_run_context
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.pipeline import Pipeline
from graphrag.index.typing.pipeline_run_result import PipelineRunResult
from graphrag.index.update.incremental_index import get_delta_docs
from graphrag.logger.base import ProgressLogger
from graphrag.logger.progress import Progress
from graphrag.storage.pipeline_storage import PipelineStorage
@ -40,25 +36,17 @@ async def run_pipeline(
"""Run all workflows using a simplified pipeline."""
root_dir = config.root_dir
storage = create_storage_from_config(config.output)
input_storage = create_storage_from_config(config.input.storage)
output_storage = create_storage_from_config(config.output)
cache = create_cache_from_config(config.cache, root_dir)
dataset = await create_input(config.input, logger, root_dir)
# load existing state in case any workflows are stateful
state_json = await storage.get("context.json")
state_json = await output_storage.get("context.json")
state = json.loads(state_json) if state_json else {}
if is_update_run:
logger.info("Running incremental indexing.")
delta_dataset = await get_delta_docs(dataset, storage)
# warn on empty delta dataset
if delta_dataset.new_inputs.empty:
warning_msg = "Incremental indexing found no new documents, exiting."
logger.warning(warning_msg)
else:
update_storage = create_storage_from_config(config.update_index_output)
# we use this to store the new subset index, and will merge its content with the previous index
update_timestamp = time.strftime("%Y%m%d-%H%M%S")
@ -67,37 +55,35 @@ async def run_pipeline(
# copy the previous output to a backup folder, so we can replace it with the update
# we'll read from this later when we merge the old and new indexes
previous_storage = timestamped_storage.child("previous")
await _copy_previous_output(storage, previous_storage)
await _copy_previous_output(output_storage, previous_storage)
state["update_timestamp"] = update_timestamp
context = create_run_context(
storage=delta_storage, cache=cache, callbacks=callbacks, state=state
input_storage=input_storage,
output_storage=delta_storage,
previous_storage=previous_storage,
cache=cache,
callbacks=callbacks,
state=state,
progress_logger=logger,
)
# Run the pipeline on the new documents
async for table in _run_pipeline(
pipeline=pipeline,
config=config,
dataset=delta_dataset.new_inputs,
logger=logger,
context=context,
):
yield table
logger.success("Finished running workflows on new documents.")
else:
logger.info("Running standard indexing.")
context = create_run_context(
storage=storage, cache=cache, callbacks=callbacks, state=state
input_storage=input_storage,
output_storage=output_storage,
cache=cache,
callbacks=callbacks,
state=state,
progress_logger=logger,
)
async for table in _run_pipeline(
pipeline=pipeline,
config=config,
dataset=dataset,
logger=logger,
context=context,
):
@ -107,19 +93,15 @@ async def run_pipeline(
async def _run_pipeline(
pipeline: Pipeline,
config: GraphRagConfig,
dataset: pd.DataFrame,
logger: ProgressLogger,
context: PipelineRunContext,
) -> AsyncIterable[PipelineRunResult]:
start_time = time.time()
log.info("Final # of rows loaded: %s", len(dataset))
context.stats.num_documents = len(dataset)
last_workflow = "starting documents"
last_workflow = "<startup>"
try:
await _dump_json(context)
await write_table_to_storage(dataset, "documents", context.storage)
for name, workflow_function in pipeline.run():
last_workflow = name
@ -132,8 +114,10 @@ async def _run_pipeline(
yield PipelineRunResult(
workflow=name, result=result.result, state=context.state, errors=None
)
context.stats.workflows[name] = {"overall": time.time() - work_time}
if result.stop:
logger.info("Halting pipeline at workflow request")
break
context.stats.total_runtime = time.time() - start_time
await _dump_json(context)
@ -148,10 +132,10 @@ async def _run_pipeline(
async def _dump_json(context: PipelineRunContext) -> None:
"""Dump the stats and context state to the storage."""
await context.storage.set(
await context.output_storage.set(
"stats.json", json.dumps(asdict(context.stats), indent=4, ensure_ascii=False)
)
await context.storage.set(
await context.output_storage.set(
"context.json", json.dumps(context.state, indent=4, ensure_ascii=False)
)

View File

@ -14,24 +14,31 @@ from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.state import PipelineState
from graphrag.index.typing.stats import PipelineRunStats
from graphrag.logger.base import ProgressLogger
from graphrag.logger.null_progress import NullProgressLogger
from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage
from graphrag.storage.pipeline_storage import PipelineStorage
from graphrag.utils.api import create_storage_from_config
def create_run_context(
storage: PipelineStorage | None = None,
input_storage: PipelineStorage | None = None,
output_storage: PipelineStorage | None = None,
previous_storage: PipelineStorage | None = None,
cache: PipelineCache | None = None,
callbacks: WorkflowCallbacks | None = None,
progress_logger: ProgressLogger | None = None,
stats: PipelineRunStats | None = None,
state: PipelineState | None = None,
) -> PipelineRunContext:
"""Create the run context for the pipeline."""
return PipelineRunContext(
stats=stats or PipelineRunStats(),
input_storage=input_storage or MemoryPipelineStorage(),
output_storage=output_storage or MemoryPipelineStorage(),
previous_storage=previous_storage or MemoryPipelineStorage(),
cache=cache or InMemoryCache(),
storage=storage or MemoryPipelineStorage(),
callbacks=callbacks or NoopWorkflowCallbacks(),
progress_logger=progress_logger or NullProgressLogger(),
stats=stats or PipelineRunStats(),
state=state or {},
)
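
With every storage argument optional, create_run_context still yields a fully in-memory context, which is convenient for exercising a single workflow in isolation; a small sketch:

from graphrag.index.run.utils import create_run_context

# input, output, and previous storage all default to MemoryPipelineStorage,
# with an in-memory cache, no-op callbacks, and a null progress logger.
context = create_run_context()
assert context.state == {}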

View File

@ -10,6 +10,7 @@ from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.index.typing.state import PipelineState
from graphrag.index.typing.stats import PipelineRunStats
from graphrag.logger.base import ProgressLogger
from graphrag.storage.pipeline_storage import PipelineStorage
@ -18,11 +19,17 @@ class PipelineRunContext:
"""Provides the context for the current pipeline run."""
stats: PipelineRunStats
storage: PipelineStorage
input_storage: PipelineStorage
"Storage for input documents."
output_storage: PipelineStorage
"Long-term storage for pipeline verbs to use. Items written here will be written to the storage provider."
previous_storage: PipelineStorage
"Storage for previous pipeline run when running in update mode."
cache: PipelineCache
"Cache instance for reading previous LLM responses."
callbacks: WorkflowCallbacks
"Callbacks to be called during the pipeline run."
progress_logger: ProgressLogger
"Progress logger for the pipeline run."
state: PipelineState
"Arbitrary property bag for runtime state, persistent pre-computes, or experimental features."

View File

@ -15,6 +15,8 @@ class PipelineRunStats:
num_documents: int = field(default=0)
"""Number of documents."""
update_documents: int = field(default=0)
"""Number of update documents."""
input_load_time: float = field(default=0)
"""Float representing the input load time."""

View File

@ -17,6 +17,8 @@ class WorkflowFunctionOutput:
result: Any | None
"""The result of the workflow function. This can be anything - we use it only for logging downstream, and expect each workflow function to write official outputs to the provided storage."""
stop: bool = False
"""Flag to indicate if the workflow should stop after this function. This should only be used when continuation could cause an unstable failure."""
WorkflowFunction = Callable[

View File

@ -39,6 +39,12 @@ from .finalize_graph import (
from .generate_text_embeddings import (
run_workflow as run_generate_text_embeddings,
)
from .load_input_documents import (
run_workflow as run_load_input_documents,
)
from .load_update_documents import (
run_workflow as run_load_update_documents,
)
from .prune_graph import (
run_workflow as run_prune_graph,
)
@ -69,6 +75,8 @@ from .update_text_units import (
# register all of our built-in workflows at once
PipelineFactory.register_all({
"load_input_documents": run_load_input_documents,
"load_update_documents": run_load_update_documents,
"create_base_text_units": run_create_base_text_units,
"create_communities": run_create_communities,
"create_community_reports_text": run_create_community_reports_text,

View File

@ -25,7 +25,7 @@ async def run_workflow(
context: PipelineRunContext,
) -> WorkflowFunctionOutput:
"""All the steps to transform base text_units."""
documents = await load_table_from_storage("documents", context.storage)
documents = await load_table_from_storage("documents", context.output_storage)
chunks = config.chunks
@ -41,7 +41,7 @@ async def run_workflow(
chunk_size_includes_metadata=chunks.chunk_size_includes_metadata,
)
await write_table_to_storage(output, "text_units", context.storage)
await write_table_to_storage(output, "text_units", context.output_storage)
return WorkflowFunctionOutput(result=output)

View File

@ -24,8 +24,10 @@ async def run_workflow(
context: PipelineRunContext,
) -> WorkflowFunctionOutput:
"""All the steps to transform final communities."""
entities = await load_table_from_storage("entities", context.storage)
relationships = await load_table_from_storage("relationships", context.storage)
entities = await load_table_from_storage("entities", context.output_storage)
relationships = await load_table_from_storage(
"relationships", context.output_storage
)
max_cluster_size = config.cluster_graph.max_cluster_size
use_lcc = config.cluster_graph.use_lcc
@ -39,7 +41,7 @@ async def run_workflow(
seed=seed,
)
await write_table_to_storage(output, "communities", context.storage)
await write_table_to_storage(output, "communities", context.output_storage)
return WorkflowFunctionOutput(result=output)

View File

@ -38,14 +38,14 @@ async def run_workflow(
context: PipelineRunContext,
) -> WorkflowFunctionOutput:
"""All the steps to transform community reports."""
edges = await load_table_from_storage("relationships", context.storage)
entities = await load_table_from_storage("entities", context.storage)
communities = await load_table_from_storage("communities", context.storage)
edges = await load_table_from_storage("relationships", context.output_storage)
entities = await load_table_from_storage("entities", context.output_storage)
communities = await load_table_from_storage("communities", context.output_storage)
claims = None
if config.extract_claims.enabled and await storage_has_table(
"covariates", context.storage
"covariates", context.output_storage
):
claims = await load_table_from_storage("covariates", context.storage)
claims = await load_table_from_storage("covariates", context.output_storage)
community_reports_llm_settings = config.get_language_model_config(
config.community_reports.model_id
@ -68,7 +68,7 @@ async def run_workflow(
num_threads=num_threads,
)
await write_table_to_storage(output, "community_reports", context.storage)
await write_table_to_storage(output, "community_reports", context.output_storage)
return WorkflowFunctionOutput(result=output)

View File

@ -37,10 +37,10 @@ async def run_workflow(
context: PipelineRunContext,
) -> WorkflowFunctionOutput:
"""All the steps to transform community reports."""
entities = await load_table_from_storage("entities", context.storage)
communities = await load_table_from_storage("communities", context.storage)
entities = await load_table_from_storage("entities", context.output_storage)
communities = await load_table_from_storage("communities", context.output_storage)
text_units = await load_table_from_storage("text_units", context.storage)
text_units = await load_table_from_storage("text_units", context.output_storage)
community_reports_llm_settings = config.get_language_model_config(
config.community_reports.model_id
@ -62,7 +62,7 @@ async def run_workflow(
num_threads=num_threads,
)
await write_table_to_storage(output, "community_reports", context.storage)
await write_table_to_storage(output, "community_reports", context.output_storage)
return WorkflowFunctionOutput(result=output)

View File

@ -17,12 +17,12 @@ async def run_workflow(
context: PipelineRunContext,
) -> WorkflowFunctionOutput:
"""All the steps to transform final documents."""
documents = await load_table_from_storage("documents", context.storage)
text_units = await load_table_from_storage("text_units", context.storage)
documents = await load_table_from_storage("documents", context.output_storage)
text_units = await load_table_from_storage("text_units", context.output_storage)
output = create_final_documents(documents, text_units)
await write_table_to_storage(output, "documents", context.storage)
await write_table_to_storage(output, "documents", context.output_storage)
return WorkflowFunctionOutput(result=output)

View File

@ -21,16 +21,18 @@ async def run_workflow(
context: PipelineRunContext,
) -> WorkflowFunctionOutput:
"""All the steps to transform the text units."""
text_units = await load_table_from_storage("text_units", context.storage)
final_entities = await load_table_from_storage("entities", context.storage)
text_units = await load_table_from_storage("text_units", context.output_storage)
final_entities = await load_table_from_storage("entities", context.output_storage)
final_relationships = await load_table_from_storage(
"relationships", context.storage
"relationships", context.output_storage
)
final_covariates = None
if config.extract_claims.enabled and await storage_has_table(
"covariates", context.storage
"covariates", context.output_storage
):
final_covariates = await load_table_from_storage("covariates", context.storage)
final_covariates = await load_table_from_storage(
"covariates", context.output_storage
)
output = create_final_text_units(
text_units,
@ -39,7 +41,7 @@ async def run_workflow(
final_covariates,
)
await write_table_to_storage(output, "text_units", context.storage)
await write_table_to_storage(output, "text_units", context.output_storage)
return WorkflowFunctionOutput(result=output)

View File

@ -26,7 +26,9 @@ async def run_workflow(
context: PipelineRunContext,
) -> WorkflowFunctionOutput:
"""All the steps to extract and format covariates."""
text_units = await load_table_from_storage("text_units", context.storage)
output = None
if config.extract_claims.enabled:
text_units = await load_table_from_storage("text_units", context.output_storage)
extract_claims_llm_settings = config.get_language_model_config(
config.extract_claims.model_id
@ -49,7 +51,7 @@ async def run_workflow(
num_threads=num_threads,
)
await write_table_to_storage(output, "covariates", context.storage)
await write_table_to_storage(output, "covariates", context.output_storage)
return WorkflowFunctionOutput(result=output)

View File

@ -27,7 +27,7 @@ async def run_workflow(
context: PipelineRunContext,
) -> WorkflowFunctionOutput:
"""All the steps to create the base entity graph."""
text_units = await load_table_from_storage("text_units", context.storage)
text_units = await load_table_from_storage("text_units", context.output_storage)
extract_graph_llm_settings = config.get_language_model_config(
config.extract_graph.model_id
@ -55,13 +55,15 @@ async def run_workflow(
summarization_num_threads=summarization_llm_settings.concurrent_requests,
)
await write_table_to_storage(entities, "entities", context.storage)
await write_table_to_storage(relationships, "relationships", context.storage)
await write_table_to_storage(entities, "entities", context.output_storage)
await write_table_to_storage(relationships, "relationships", context.output_storage)
if config.snapshots.raw_graph:
await write_table_to_storage(raw_entities, "raw_entities", context.storage)
await write_table_to_storage(
raw_relationships, "raw_relationships", context.storage
raw_entities, "raw_entities", context.output_storage
)
await write_table_to_storage(
raw_relationships, "raw_relationships", context.output_storage
)
return WorkflowFunctionOutput(

View File

@ -22,7 +22,7 @@ async def run_workflow(
context: PipelineRunContext,
) -> WorkflowFunctionOutput:
"""All the steps to create the base entity graph."""
text_units = await load_table_from_storage("text_units", context.storage)
text_units = await load_table_from_storage("text_units", context.output_storage)
entities, relationships = await extract_graph_nlp(
text_units,
@ -30,8 +30,8 @@ async def run_workflow(
extraction_config=config.extract_graph_nlp,
)
await write_table_to_storage(entities, "entities", context.storage)
await write_table_to_storage(relationships, "relationships", context.storage)
await write_table_to_storage(entities, "entities", context.output_storage)
await write_table_to_storage(relationships, "relationships", context.output_storage)
return WorkflowFunctionOutput(
result={

View File

@ -15,6 +15,7 @@ class PipelineFactory:
"""A factory class for workflow pipelines."""
workflows: ClassVar[dict[str, WorkflowFunction]] = {}
pipelines: ClassVar[dict[str, list[str]]] = {}
@classmethod
def register(cls, name: str, workflow: WorkflowFunction):
@ -27,53 +28,35 @@ class PipelineFactory:
for name, workflow in workflows.items():
cls.register(name, workflow)
@classmethod
def register_pipeline(cls, name: str, workflows: list[str]):
"""Register a new pipeline method as a list of workflow names."""
cls.pipelines[name] = workflows
@classmethod
def create_pipeline(
cls,
config: GraphRagConfig,
method: IndexingMethod = IndexingMethod.Standard,
is_update_run: bool = False,
method: IndexingMethod | str = IndexingMethod.Standard,
) -> Pipeline:
"""Create a pipeline generator."""
workflows = _get_workflows_list(config, method, is_update_run)
workflows = config.workflows or cls.pipelines.get(method, [])
return Pipeline([(name, cls.workflows[name]) for name in workflows])
def _get_workflows_list(
config: GraphRagConfig,
method: IndexingMethod = IndexingMethod.Standard,
is_update_run: bool = False,
) -> list[str]:
"""Return a list of workflows for the indexing pipeline."""
update_workflows = [
"update_final_documents",
"update_entities_relationships",
"update_text_units",
"update_covariates",
"update_communities",
"update_community_reports",
"update_text_embeddings",
"update_clean_state",
]
if config.workflows:
return config.workflows
match method:
case IndexingMethod.Standard:
return [
# --- Register default implementations ---
_standard_workflows = [
"create_base_text_units",
"create_final_documents",
"extract_graph",
"finalize_graph",
*(["extract_covariates"] if config.extract_claims.enabled else []),
"extract_covariates",
"create_communities",
"create_final_text_units",
"create_community_reports",
"generate_text_embeddings",
*(update_workflows if is_update_run else []),
]
case IndexingMethod.Fast:
return [
]
_fast_workflows = [
"create_base_text_units",
"create_final_documents",
"extract_graph_nlp",
@ -83,5 +66,28 @@ def _get_workflows_list(
"create_final_text_units",
"create_community_reports_text",
"generate_text_embeddings",
*(update_workflows if is_update_run else []),
]
]
_update_workflows = [
"update_final_documents",
"update_entities_relationships",
"update_text_units",
"update_covariates",
"update_communities",
"update_community_reports",
"update_text_embeddings",
"update_clean_state",
]
PipelineFactory.register_pipeline(
IndexingMethod.Standard, ["load_input_documents", *_standard_workflows]
)
PipelineFactory.register_pipeline(
IndexingMethod.Fast, ["load_input_documents", *_fast_workflows]
)
PipelineFactory.register_pipeline(
IndexingMethod.StandardUpdate,
["load_update_documents", *_standard_workflows, *_update_workflows],
)
PipelineFactory.register_pipeline(
IndexingMethod.FastUpdate,
["load_update_documents", *_fast_workflows, *_update_workflows],
)
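
Beyond overriding the workflows list in settings.yaml, an entire pipeline can now be registered under its own method name and selected by passing that string to build_index; a sketch, assuming PipelineFactory is importable from graphrag.index.workflows.factory:

from graphrag.index.workflows.factory import PipelineFactory  # assumed module path

# A custom method composed from built-in (or previously registered) workflow names.
PipelineFactory.register_pipeline(
    "my-minimal",
    [
        "load_input_documents",
        "create_base_text_units",
        "create_final_documents",
    ],
)

# build_index(config, method="my-minimal") then runs exactly these workflows,
# since create_pipeline falls back to this registry when config.workflows is unset.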

View File

@ -22,8 +22,10 @@ async def run_workflow(
context: PipelineRunContext,
) -> WorkflowFunctionOutput:
"""All the steps to create the base entity graph."""
entities = await load_table_from_storage("entities", context.storage)
relationships = await load_table_from_storage("relationships", context.storage)
entities = await load_table_from_storage("entities", context.output_storage)
relationships = await load_table_from_storage(
"relationships", context.output_storage
)
final_entities, final_relationships = finalize_graph(
entities,
@ -33,8 +35,10 @@ async def run_workflow(
layout_enabled=config.umap.enabled,
)
await write_table_to_storage(final_entities, "entities", context.storage)
await write_table_to_storage(final_relationships, "relationships", context.storage)
await write_table_to_storage(final_entities, "entities", context.output_storage)
await write_table_to_storage(
final_relationships, "relationships", context.output_storage
)
if config.snapshots.graphml:
# todo: extract graphs at each level, and add in meta like descriptions
@ -43,7 +47,7 @@ async def run_workflow(
await snapshot_graphml(
graph,
name="graph",
storage=context.storage,
storage=context.output_storage,
)
return WorkflowFunctionOutput(

View File

@ -43,17 +43,19 @@ async def run_workflow(
text_units = None
entities = None
community_reports = None
if await storage_has_table("documents", context.storage):
documents = await load_table_from_storage("documents", context.storage)
if await storage_has_table("relationships", context.storage):
relationships = await load_table_from_storage("relationships", context.storage)
if await storage_has_table("text_units", context.storage):
text_units = await load_table_from_storage("text_units", context.storage)
if await storage_has_table("entities", context.storage):
entities = await load_table_from_storage("entities", context.storage)
if await storage_has_table("community_reports", context.storage):
if await storage_has_table("documents", context.output_storage):
documents = await load_table_from_storage("documents", context.output_storage)
if await storage_has_table("relationships", context.output_storage):
relationships = await load_table_from_storage(
"relationships", context.output_storage
)
if await storage_has_table("text_units", context.output_storage):
text_units = await load_table_from_storage("text_units", context.output_storage)
if await storage_has_table("entities", context.output_storage):
entities = await load_table_from_storage("entities", context.output_storage)
if await storage_has_table("community_reports", context.output_storage):
community_reports = await load_table_from_storage(
"community_reports", context.storage
"community_reports", context.output_storage
)
embedded_fields = config.embed_text.names
@ -76,7 +78,7 @@ async def run_workflow(
await write_table_to_storage(
table,
f"embeddings.{name}",
context.storage,
context.output_storage,
)
return WorkflowFunctionOutput(result=output)

View File

@ -0,0 +1,45 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing run_workflow method definition."""
import logging
import pandas as pd
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.config.models.input_config import InputConfig
from graphrag.index.input.factory import create_input
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.logger.base import ProgressLogger
from graphrag.storage.pipeline_storage import PipelineStorage
from graphrag.utils.storage import write_table_to_storage
log = logging.getLogger(__name__)
async def run_workflow(
config: GraphRagConfig,
context: PipelineRunContext,
) -> WorkflowFunctionOutput:
"""Load and parse input documents into a standard format."""
output = await load_input_documents(
config.input,
context.input_storage,
context.progress_logger,
)
log.info("Final # of rows loaded: %s", len(output))
context.stats.num_documents = len(output)
await write_table_to_storage(output, "documents", context.output_storage)
return WorkflowFunctionOutput(result=output)
async def load_input_documents(
config: InputConfig, storage: PipelineStorage, progress_logger: ProgressLogger
) -> pd.DataFrame:
"""Load and parse input documents into a standard format."""
return await create_input(config, storage, progress_logger)

View File

@ -0,0 +1,59 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing run_workflow method definition."""
import logging
import pandas as pd
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.config.models.input_config import InputConfig
from graphrag.index.input.factory import create_input
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.workflow import WorkflowFunctionOutput
from graphrag.index.update.incremental_index import get_delta_docs
from graphrag.logger.base import ProgressLogger
from graphrag.storage.pipeline_storage import PipelineStorage
from graphrag.utils.storage import write_table_to_storage
log = logging.getLogger(__name__)
async def run_workflow(
config: GraphRagConfig,
context: PipelineRunContext,
) -> WorkflowFunctionOutput:
"""Load and parse update-only input documents into a standard format."""
output = await load_update_documents(
config.input,
context.input_storage,
context.previous_storage,
context.progress_logger,
)
log.info("Final # of update rows loaded: %s", len(output))
context.stats.update_documents = len(output)
if len(output) == 0:
log.warning("No new update documents found.")
context.progress_logger.warning("No new update documents found.")
return WorkflowFunctionOutput(result=None, stop=True)
await write_table_to_storage(output, "documents", context.output_storage)
return WorkflowFunctionOutput(result=output)
async def load_update_documents(
config: InputConfig,
input_storage: PipelineStorage,
previous_storage: PipelineStorage,
progress_logger: ProgressLogger,
) -> pd.DataFrame:
"""Load and parse update-only input documents into a standard format."""
input_documents = await create_input(config, input_storage, progress_logger)
# previous storage is the output of the previous run
# we'll use this to diff the input from the prior
delta_documents = await get_delta_docs(input_documents, previous_storage)
return delta_documents.new_inputs

View File

@ -20,8 +20,10 @@ async def run_workflow(
context: PipelineRunContext,
) -> WorkflowFunctionOutput:
"""All the steps to create the base entity graph."""
entities = await load_table_from_storage("entities", context.storage)
relationships = await load_table_from_storage("relationships", context.storage)
entities = await load_table_from_storage("entities", context.output_storage)
relationships = await load_table_from_storage(
"relationships", context.output_storage
)
pruned_entities, pruned_relationships = prune_graph(
entities,
@ -29,8 +31,10 @@ async def run_workflow(
pruning_config=config.prune_graph,
)
await write_table_to_storage(pruned_entities, "entities", context.storage)
await write_table_to_storage(pruned_relationships, "relationships", context.storage)
await write_table_to_storage(pruned_entities, "entities", context.output_storage)
await write_table_to_storage(
pruned_relationships, "relationships", context.output_storage
)
return WorkflowFunctionOutput(
result={

View File

@ -21,6 +21,7 @@ from graphrag.prompt_tune.defaults import (
K,
)
from graphrag.prompt_tune.types import DocSelectionType
from graphrag.utils.api import create_storage_from_config
def _sample_chunks_from_embeddings(
@ -37,7 +38,6 @@ def _sample_chunks_from_embeddings(
async def load_docs_in_chunks(
root: str,
config: GraphRagConfig,
select_method: DocSelectionType,
limit: int,
@ -51,7 +51,8 @@ async def load_docs_in_chunks(
embeddings_llm_settings = config.get_language_model_config(
config.embed_text.model_id
)
dataset = await create_input(config.input, logger, root)
input_storage = create_storage_from_config(config.input.storage)
dataset = await create_input(config.input, input_storage, logger)
chunk_config = config.chunks
chunks_df = create_base_text_units(
documents=dataset,

View File

@ -7,7 +7,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, ClassVar
from graphrag.config.enums import OutputType
from graphrag.config.enums import StorageType
from graphrag.storage.blob_pipeline_storage import create_blob_storage
from graphrag.storage.cosmosdb_pipeline_storage import create_cosmosdb_storage
from graphrag.storage.file_pipeline_storage import create_file_storage
@ -35,17 +35,17 @@ class StorageFactory:
@classmethod
def create_storage(
cls, storage_type: OutputType | str, kwargs: dict
cls, storage_type: StorageType | str, kwargs: dict
) -> PipelineStorage:
"""Create or get a storage object from the provided type."""
match storage_type:
case OutputType.blob:
case StorageType.blob:
return create_blob_storage(**kwargs)
case OutputType.cosmosdb:
case StorageType.cosmosdb:
return create_cosmosdb_storage(**kwargs)
case OutputType.file:
case StorageType.file:
return create_file_storage(**kwargs)
case OutputType.memory:
case StorageType.memory:
return MemoryPipelineStorage()
case _:
if storage_type in cls.storage_types:

View File

@ -10,7 +10,7 @@ from graphrag.cache.factory import CacheFactory
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.config.embeddings import create_collection_name
from graphrag.config.models.cache_config import CacheConfig
from graphrag.config.models.output_config import OutputConfig
from graphrag.config.models.storage_config import StorageConfig
from graphrag.data_model.types import TextEmbedder
from graphrag.storage.factory import StorageFactory
from graphrag.storage.pipeline_storage import PipelineStorage
@ -238,7 +238,7 @@ def load_search_prompt(root_dir: str, prompt_config: str | None) -> str | None:
return None
def create_storage_from_config(output: OutputConfig) -> PipelineStorage:
def create_storage_from_config(output: StorageConfig) -> PipelineStorage:
"""Create a storage object from the config."""
storage_config = output.model_dump()
return StorageFactory().create_storage(

View File

@ -9,11 +9,13 @@ vector_store:
container_name: "azure_ci"
input:
storage:
type: blob
file_type: text
connection_string: ${LOCAL_BLOB_STORAGE_CONNECTION_STRING}
container_name: azurefixture
base_dir: input
file_type: text
cache:
type: blob
@ -21,7 +23,7 @@ cache:
container_name: cicache
base_dir: cache_azure_ai
storage:
output:
type: blob
connection_string: ${LOCAL_BLOB_STORAGE_CONNECTION_STRING}
container_name: azurefixture

View File

@ -2,9 +2,15 @@
"input_path": "./tests/fixtures/min-csv",
"input_file_type": "text",
"workflow_config": {
"load_input_documents": {
"max_runtime": 30
},
"create_base_text_units": {
"max_runtime": 30
},
"extract_covariates": {
"max_runtime": 10
},
"extract_graph": {
"max_runtime": 500
},

View File

@ -2,6 +2,9 @@
"input_path": "./tests/fixtures/text",
"input_file_type": "text",
"workflow_config": {
"load_input_documents": {
"max_runtime": 30
},
"create_base_text_units": {
"max_runtime": 30
},

View File

@ -9,7 +9,7 @@ import sys
import pytest
from graphrag.config.enums import OutputType
from graphrag.config.enums import StorageType
from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage
from graphrag.storage.factory import StorageFactory
@ -29,7 +29,7 @@ def test_create_blob_storage():
"base_dir": "testbasedir",
"container_name": "testcontainer",
}
storage = StorageFactory.create_storage(OutputType.blob, kwargs)
storage = StorageFactory.create_storage(StorageType.blob, kwargs)
assert isinstance(storage, BlobPipelineStorage)
@ -44,19 +44,19 @@ def test_create_cosmosdb_storage():
"base_dir": "testdatabase",
"container_name": "testcontainer",
}
storage = StorageFactory.create_storage(OutputType.cosmosdb, kwargs)
storage = StorageFactory.create_storage(StorageType.cosmosdb, kwargs)
assert isinstance(storage, CosmosDBPipelineStorage)
def test_create_file_storage():
kwargs = {"type": "file", "base_dir": "/tmp/teststorage"}
storage = StorageFactory.create_storage(OutputType.file, kwargs)
storage = StorageFactory.create_storage(StorageType.file, kwargs)
assert isinstance(storage, FilePipelineStorage)
def test_create_memory_storage():
kwargs = {"type": "memory"}
storage = StorageFactory.create_storage(OutputType.memory, kwargs)
storage = StorageFactory.create_storage(StorageType.memory, kwargs)
assert isinstance(storage, MemoryPipelineStorage)

View File

@ -24,10 +24,10 @@ from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.config.models.input_config import InputConfig
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.config.models.local_search_config import LocalSearchConfig
from graphrag.config.models.output_config import OutputConfig
from graphrag.config.models.prune_graph_config import PruneGraphConfig
from graphrag.config.models.reporting_config import ReportingConfig
from graphrag.config.models.snapshots_config import SnapshotsConfig
from graphrag.config.models.storage_config import StorageConfig
from graphrag.config.models.summarize_descriptions_config import (
SummarizeDescriptionsConfig,
)
@ -134,7 +134,7 @@ def assert_reporting_configs(
assert actual.storage_account_blob_url == expected.storage_account_blob_url
def assert_output_configs(actual: OutputConfig, expected: OutputConfig) -> None:
def assert_output_configs(actual: StorageConfig, expected: StorageConfig) -> None:
assert expected.type == actual.type
assert expected.base_dir == actual.base_dir
assert expected.connection_string == actual.connection_string
@ -143,7 +143,9 @@ def assert_output_configs(actual: OutputConfig, expected: OutputConfig) -> None:
assert expected.cosmosdb_account_url == actual.cosmosdb_account_url
def assert_update_output_configs(actual: OutputConfig, expected: OutputConfig) -> None:
def assert_update_output_configs(
actual: StorageConfig, expected: StorageConfig
) -> None:
assert expected.type == actual.type
assert expected.base_dir == actual.base_dir
assert expected.connection_string == actual.connection_string
@ -162,12 +164,15 @@ def assert_cache_configs(actual: CacheConfig, expected: CacheConfig) -> None:
def assert_input_configs(actual: InputConfig, expected: InputConfig) -> None:
assert actual.type == expected.type
assert actual.storage.type == expected.storage.type
assert actual.file_type == expected.file_type
assert actual.base_dir == expected.base_dir
assert actual.connection_string == expected.connection_string
assert actual.storage_account_blob_url == expected.storage_account_blob_url
assert actual.container_name == expected.container_name
assert actual.storage.base_dir == expected.storage.base_dir
assert actual.storage.connection_string == expected.storage.connection_string
assert (
actual.storage.storage_account_blob_url
== expected.storage.storage_account_blob_url
)
assert actual.storage.container_name == expected.storage.container_name
assert actual.encoding == expected.encoding
assert actual.file_pattern == expected.file_pattern
assert actual.file_filter == expected.file_filter

View File

@ -1,56 +1,66 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from graphrag.config.enums import InputFileType, InputType
from graphrag.config.enums import InputFileType
from graphrag.config.models.input_config import InputConfig
from graphrag.config.models.storage_config import StorageConfig
from graphrag.index.input.factory import create_input
from graphrag.utils.api import create_storage_from_config
async def test_csv_loader_one_file():
config = InputConfig(
type=InputType.file,
storage=StorageConfig(
base_dir="tests/unit/indexing/input/data/one-csv",
),
file_type=InputFileType.csv,
file_pattern=".*\\.csv$",
base_dir="tests/unit/indexing/input/data/one-csv",
)
documents = await create_input(config=config)
storage = create_storage_from_config(config.storage)
documents = await create_input(config=config, storage=storage)
assert documents.shape == (2, 4)
assert documents["title"].iloc[0] == "input.csv"
async def test_csv_loader_one_file_with_title():
config = InputConfig(
type=InputType.file,
storage=StorageConfig(
base_dir="tests/unit/indexing/input/data/one-csv",
),
file_type=InputFileType.csv,
file_pattern=".*\\.csv$",
base_dir="tests/unit/indexing/input/data/one-csv",
title_column="title",
)
documents = await create_input(config=config)
storage = create_storage_from_config(config.storage)
documents = await create_input(config=config, storage=storage)
assert documents.shape == (2, 4)
assert documents["title"].iloc[0] == "Hello"
async def test_csv_loader_one_file_with_metadata():
config = InputConfig(
type=InputType.file,
storage=StorageConfig(
base_dir="tests/unit/indexing/input/data/one-csv",
),
file_type=InputFileType.csv,
file_pattern=".*\\.csv$",
base_dir="tests/unit/indexing/input/data/one-csv",
title_column="title",
metadata=["title"],
)
documents = await create_input(config=config)
storage = create_storage_from_config(config.storage)
documents = await create_input(config=config, storage=storage)
assert documents.shape == (2, 5)
assert documents["metadata"][0] == {"title": "Hello"}
async def test_csv_loader_multiple_files():
config = InputConfig(
type=InputType.file,
storage=StorageConfig(
base_dir="tests/unit/indexing/input/data/multiple-csvs",
),
file_type=InputFileType.csv,
file_pattern=".*\\.csv$",
base_dir="tests/unit/indexing/input/data/multiple-csvs",
)
documents = await create_input(config=config)
storage = create_storage_from_config(config.storage)
documents = await create_input(config=config, storage=storage)
assert documents.shape == (4, 4)

View File

@ -1,31 +1,37 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from graphrag.config.enums import InputFileType, InputType
from graphrag.config.enums import InputFileType
from graphrag.config.models.input_config import InputConfig
from graphrag.config.models.storage_config import StorageConfig
from graphrag.index.input.factory import create_input
from graphrag.utils.api import create_storage_from_config
async def test_json_loader_one_file_one_object():
config = InputConfig(
type=InputType.file,
storage=StorageConfig(
base_dir="tests/unit/indexing/input/data/one-json-one-object",
),
file_type=InputFileType.json,
file_pattern=".*\\.json$",
base_dir="tests/unit/indexing/input/data/one-json-one-object",
)
documents = await create_input(config=config)
storage = create_storage_from_config(config.storage)
documents = await create_input(config=config, storage=storage)
assert documents.shape == (1, 4)
assert documents["title"].iloc[0] == "input.json"
async def test_json_loader_one_file_multiple_objects():
config = InputConfig(
type=InputType.file,
storage=StorageConfig(
base_dir="tests/unit/indexing/input/data/one-json-multiple-objects",
),
file_type=InputFileType.json,
file_pattern=".*\\.json$",
base_dir="tests/unit/indexing/input/data/one-json-multiple-objects",
)
documents = await create_input(config=config)
storage = create_storage_from_config(config.storage)
documents = await create_input(config=config, storage=storage)
print(documents)
assert documents.shape == (3, 4)
assert documents["title"].iloc[0] == "input.json"
@@ -33,37 +39,43 @@ async def test_json_loader_one_file_multiple_objects():
async def test_json_loader_one_file_with_title():
config = InputConfig(
type=InputType.file,
storage=StorageConfig(
base_dir="tests/unit/indexing/input/data/one-json-one-object",
),
file_type=InputFileType.json,
file_pattern=".*\\.json$",
base_dir="tests/unit/indexing/input/data/one-json-one-object",
title_column="title",
)
documents = await create_input(config=config)
storage = create_storage_from_config(config.storage)
documents = await create_input(config=config, storage=storage)
assert documents.shape == (1, 4)
assert documents["title"].iloc[0] == "Hello"
async def test_json_loader_one_file_with_metadata():
config = InputConfig(
type=InputType.file,
storage=StorageConfig(
base_dir="tests/unit/indexing/input/data/one-json-one-object",
),
file_type=InputFileType.json,
file_pattern=".*\\.json$",
base_dir="tests/unit/indexing/input/data/one-json-one-object",
title_column="title",
metadata=["title"],
)
documents = await create_input(config=config)
storage = create_storage_from_config(config.storage)
documents = await create_input(config=config, storage=storage)
assert documents.shape == (1, 5)
assert documents["metadata"][0] == {"title": "Hello"}
async def test_json_loader_multiple_files():
config = InputConfig(
type=InputType.file,
storage=StorageConfig(
base_dir="tests/unit/indexing/input/data/multiple-jsons",
),
file_type=InputFileType.json,
file_pattern=".*\\.json$",
base_dir="tests/unit/indexing/input/data/multiple-jsons",
)
documents = await create_input(config=config)
storage = create_storage_from_config(config.storage)
documents = await create_input(config=config, storage=storage)
assert documents.shape == (4, 4)

View File

@@ -1,32 +1,38 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from graphrag.config.enums import InputFileType, InputType
from graphrag.config.enums import InputFileType
from graphrag.config.models.input_config import InputConfig
from graphrag.config.models.storage_config import StorageConfig
from graphrag.index.input.factory import create_input
from graphrag.utils.api import create_storage_from_config
async def test_txt_loader_one_file():
config = InputConfig(
type=InputType.file,
storage=StorageConfig(
base_dir="tests/unit/indexing/input/data/one-txt",
),
file_type=InputFileType.text,
file_pattern=".*\\.txt$",
base_dir="tests/unit/indexing/input/data/one-txt",
)
documents = await create_input(config=config)
storage = create_storage_from_config(config.storage)
documents = await create_input(config=config, storage=storage)
assert documents.shape == (1, 4)
assert documents["title"].iloc[0] == "input.txt"
async def test_txt_loader_one_file_with_metadata():
config = InputConfig(
type=InputType.file,
storage=StorageConfig(
base_dir="tests/unit/indexing/input/data/one-txt",
),
file_type=InputFileType.text,
file_pattern=".*\\.txt$",
base_dir="tests/unit/indexing/input/data/one-txt",
metadata=["title"],
)
documents = await create_input(config=config)
storage = create_storage_from_config(config.storage)
documents = await create_input(config=config, storage=storage)
assert documents.shape == (1, 5)
# unlike csv, we cannot set the title to anything other than the filename
assert documents["metadata"][0] == {"title": "input.txt"}
@@ -34,10 +40,12 @@ async def test_txt_loader_one_file_with_metadata():
async def test_txt_loader_multiple_files():
config = InputConfig(
type=InputType.file,
storage=StorageConfig(
base_dir="tests/unit/indexing/input/data/multiple-txts",
),
file_type=InputFileType.text,
file_pattern=".*\\.txt$",
base_dir="tests/unit/indexing/input/data/multiple-txts",
)
documents = await create_input(config=config)
storage = create_storage_from_config(config.storage)
documents = await create_input(config=config, storage=storage)
assert documents.shape == (2, 4)

View File

@@ -23,7 +23,7 @@ async def test_create_base_text_units():
await run_workflow(config, context)
actual = await load_table_from_storage("text_units", context.storage)
actual = await load_table_from_storage("text_units", context.output_storage)
compare_outputs(actual, expected, columns=["text", "document_ids", "n_tokens"])
@@ -43,7 +43,7 @@ async def test_create_base_text_units_metadata():
await run_workflow(config, context)
actual = await load_table_from_storage("text_units", context.storage)
actual = await load_table_from_storage("text_units", context.output_storage)
compare_outputs(actual, expected)
@@ -63,6 +63,6 @@ async def test_create_base_text_units_metadata_included_in_chunk():
await run_workflow(config, context)
actual = await load_table_from_storage("text_units", context.storage)
actual = await load_table_from_storage("text_units", context.output_storage)
# only check the columns from the base workflow - our expected table is the final and will have more
compare_outputs(actual, expected, columns=["text", "document_ids", "n_tokens"])
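The text-unit hunks above change nothing but the storage handle that workflow outputs are read from. A minimal sketch of the new read path, assuming load_table_from_storage is the same helper used throughout these tests and lives under graphrag.utils.storage (read_text_units is an illustrative name):

from graphrag.utils.storage import load_table_from_storage  # assumed import path

async def read_text_units(context):
    # Workflow outputs are now read from the renamed output_storage handle on the run context
    return await load_table_from_storage("text_units", context.output_storage)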

View File

@@ -33,7 +33,7 @@ async def test_create_communities():
context,
)
actual = await load_table_from_storage("communities", context.storage)
actual = await load_table_from_storage("communities", context.output_storage)
columns = list(expected.columns.values)
# don't compare period since it is created with the current date each time

View File

@@ -66,7 +66,7 @@ async def test_create_community_reports():
await run_workflow(config, context)
actual = await load_table_from_storage("community_reports", context.storage)
actual = await load_table_from_storage("community_reports", context.output_storage)
assert len(actual.columns) == len(expected.columns)

View File

@@ -28,7 +28,7 @@ async def test_create_final_documents():
await run_workflow(config, context)
actual = await load_table_from_storage("documents", context.storage)
actual = await load_table_from_storage("documents", context.output_storage)
compare_outputs(actual, expected)
@@ -47,11 +47,11 @@ async def test_create_final_documents_with_metadata_column():
# simulate the metadata construction during initial input loading
await update_document_metadata(config.input.metadata, context)
expected = await load_table_from_storage("documents", context.storage)
expected = await load_table_from_storage("documents", context.output_storage)
await run_workflow(config, context)
actual = await load_table_from_storage("documents", context.storage)
actual = await load_table_from_storage("documents", context.output_storage)
compare_outputs(actual, expected)

View File

@@ -33,7 +33,7 @@ async def test_create_final_text_units():
await run_workflow(config, context)
actual = await load_table_from_storage("text_units", context.storage)
actual = await load_table_from_storage("text_units", context.output_storage)
for column in TEXT_UNITS_FINAL_COLUMNS:
assert column in actual.columns

View File

@@ -37,6 +37,7 @@ async def test_extract_covariates():
).model_dump()
llm_settings["type"] = ModelType.MockChat
llm_settings["responses"] = MOCK_LLM_RESPONSES
config.extract_claims.enabled = True
config.extract_claims.strategy = {
"type": "graph_intelligence",
"llm": llm_settings,
@@ -45,7 +46,7 @@
await run_workflow(config, context)
actual = await load_table_from_storage("covariates", context.storage)
actual = await load_table_from_storage("covariates", context.output_storage)
for column in COVARIATES_FINAL_COLUMNS:
assert column in actual.columns
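The covariates hunk adds one functional change beyond the rename: the test switches claim extraction on explicitly, presumably because the covariate run is now conditional on that flag. A condensed sketch of the relevant configuration, with mock_responses standing in for the canned MOCK_LLM_RESPONSES used above and enable_claim_extraction as an illustrative helper name:

from graphrag.config.enums import ModelType

def enable_claim_extraction(config, llm_settings, mock_responses):
    # Replace the chat model with a mock that replays canned responses
    llm_settings["type"] = ModelType.MockChat
    llm_settings["responses"] = mock_responses
    # The workflow only runs when claim extraction is explicitly enabled
    config.extract_claims.enabled = True
    config.extract_claims.strategy = {
        "type": "graph_intelligence",
        "llm": llm_settings,
    }
    return config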

View File

@@ -63,8 +63,10 @@ async def test_extract_graph():
await run_workflow(config, context)
nodes_actual = await load_table_from_storage("entities", context.storage)
edges_actual = await load_table_from_storage("relationships", context.storage)
nodes_actual = await load_table_from_storage("entities", context.output_storage)
edges_actual = await load_table_from_storage(
"relationships", context.output_storage
)
assert len(nodes_actual.columns) == 5
assert len(edges_actual.columns) == 5

View File

@@ -22,8 +22,10 @@ async def test_extract_graph_nlp():
await run_workflow(config, context)
nodes_actual = await load_table_from_storage("entities", context.storage)
edges_actual = await load_table_from_storage("relationships", context.storage)
nodes_actual = await load_table_from_storage("entities", context.output_storage)
edges_actual = await load_table_from_storage(
"relationships", context.output_storage
)
# this will be the raw count of entities and edges with no pruning
# with NLP it is deterministic, so we can assert exact row counts

View File

@@ -25,8 +25,10 @@ async def test_finalize_graph():
await run_workflow(config, context)
nodes_actual = await load_table_from_storage("entities", context.storage)
edges_actual = await load_table_from_storage("relationships", context.storage)
nodes_actual = await load_table_from_storage("entities", context.output_storage)
edges_actual = await load_table_from_storage(
"relationships", context.output_storage
)
assert len(nodes_actual) == 291
assert len(edges_actual) == 452
@@ -51,8 +53,10 @@ async def test_finalize_graph_umap():
await run_workflow(config, context)
nodes_actual = await load_table_from_storage("entities", context.storage)
edges_actual = await load_table_from_storage("relationships", context.storage)
nodes_actual = await load_table_from_storage("entities", context.output_storage)
edges_actual = await load_table_from_storage(
"relationships", context.output_storage
)
assert len(nodes_actual) == 291
assert len(edges_actual) == 452
@@ -75,8 +79,8 @@ async def _prep_tables():
# edit the tables to eliminate final fields that wouldn't be on the inputs
entities = load_test_table("entities")
entities.drop(columns=["x", "y", "degree"], inplace=True)
await write_table_to_storage(entities, "entities", context.storage)
await write_table_to_storage(entities, "entities", context.output_storage)
relationships = load_test_table("relationships")
relationships.drop(columns=["combined_degree"], inplace=True)
await write_table_to_storage(relationships, "relationships", context.storage)
await write_table_to_storage(relationships, "relationships", context.output_storage)
return context
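_prep_tables seeds the run context with pre-finalization inputs; the only change above is the destination handle. A minimal sketch of that seeding step as it reads after the change, assuming write_table_to_storage lives under graphrag.utils.storage and load_test_table behaves as in the hunks above (seed_graph_inputs is an illustrative name):

from graphrag.utils.storage import write_table_to_storage  # assumed import path

async def seed_graph_inputs(context, load_test_table):
    # Drop columns that only exist on finalized outputs before seeding the run
    entities = load_test_table("entities")
    entities.drop(columns=["x", "y", "degree"], inplace=True)
    await write_table_to_storage(entities, "entities", context.output_storage)

    relationships = load_test_table("relationships")
    relationships.drop(columns=["combined_degree"], inplace=True)
    await write_table_to_storage(relationships, "relationships", context.output_storage)
    return context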

View File

@@ -44,14 +44,14 @@ async def test_generate_text_embeddings():
await run_workflow(config, context)
parquet_files = context.storage.keys()
parquet_files = context.output_storage.keys()
for field in all_embeddings:
assert f"embeddings.{field}.parquet" in parquet_files
# entity description should always be here, let's assert its format
entity_description_embeddings = await load_table_from_storage(
"embeddings.entity.description", context.storage
"embeddings.entity.description", context.output_storage
)
assert len(entity_description_embeddings.columns) == 2
@@ -60,7 +60,7 @@ async def test_generate_text_embeddings():
# every other embedding is optional but we've turned them all on, so check a random one
document_text_embeddings = await load_table_from_storage(
"embeddings.document.text", context.storage
"embeddings.document.text", context.output_storage
)
assert len(document_text_embeddings.columns) == 2
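The embeddings test leans on the naming convention for embedding tables in output storage. A minimal sketch of that check, assuming output_storage.keys() lists stored parquet names synchronously as in the hunk above (assert_embeddings_written is an illustrative name):

def assert_embeddings_written(context, expected_fields):
    # Each embedding target is persisted as embeddings.<field>.parquet in output storage
    parquet_files = context.output_storage.keys()
    for field in expected_fields:
        assert f"embeddings.{field}.parquet" in parquet_files

The test above effectively applies this to fields such as "entity.description" and "document.text".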

View File

@@ -26,6 +26,6 @@ async def test_prune_graph():
await run_workflow(config, context)
nodes_actual = await load_table_from_storage("entities", context.storage)
nodes_actual = await load_table_from_storage("entities", context.output_storage)
assert len(nodes_actual) == 21

View File

@@ -38,12 +38,12 @@ async def create_test_context(storage: list[str] | None = None) -> PipelineRunContext:
# always set the input docs, but since our stored table is final, drop what wouldn't be in the original source input
input = load_test_table("documents")
input.drop(columns=["text_unit_ids"], inplace=True)
await write_table_to_storage(input, "documents", context.storage)
await write_table_to_storage(input, "documents", context.output_storage)
if storage:
for name in storage:
table = load_test_table(name)
await write_table_to_storage(table, name, context.storage)
await write_table_to_storage(table, name, context.output_storage)
return context
@@ -86,8 +86,8 @@
async def update_document_metadata(metadata: list[str], context: PipelineRunContext):
"""Takes the default documents and adds the configured metadata columns for later parsing by the text units and final documents workflows."""
documents = await load_table_from_storage("documents", context.storage)
documents = await load_table_from_storage("documents", context.output_storage)
documents["metadata"] = documents[metadata].apply(lambda row: row.to_dict(), axis=1)
await write_table_to_storage(
documents, "documents", context.storage
documents, "documents", context.output_storage
) # write to the runtime context storage only
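update_document_metadata collapses the configured metadata columns into a single dict-valued column, which the text units and final documents workflows later parse back out. The pandas step at its core, shown standalone (attach_metadata is an illustrative name):

import pandas as pd

def attach_metadata(documents: pd.DataFrame, metadata: list[str]) -> pd.DataFrame:
    # Collapse the configured metadata columns into one dict per row, e.g. {"title": "Hello"}
    documents["metadata"] = documents[metadata].apply(lambda row: row.to_dict(), axis=1)
    return documents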