mirror of https://github.com/microsoft/graphrag.git
synced 2025-06-26 23:19:58 +00:00
Pipeline registration (#1940)
* Move covariate run conditional
* Add pipeline registration
* Fix method name construction
* Rename context storage -> output_storage
* Rename OutputConfig as generic StorageConfig
* Reuse Storage model under InputConfig
* Move input storage creation out of document loading
* Move document loading into workflows
* Semver
* Fix smoke test config for new workflows
* Fix unit tests
---------
Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
This commit is contained in:
parent 17e431cf42
commit 1df89727c3
@@ -0,0 +1,4 @@
+{
+    "type": "minor",
+    "description": "Allow injection of custom pipelines."
+}
@@ -27,7 +27,7 @@ log = logging.getLogger(__name__)

async def build_index(
    config: GraphRagConfig,
-    method: IndexingMethod = IndexingMethod.Standard,
+    method: IndexingMethod | str = IndexingMethod.Standard,
    is_update_run: bool = False,
    memory_profile: bool = False,
    callbacks: list[WorkflowCallbacks] | None = None,
@@ -65,7 +65,9 @@ async def build_index(
    if memory_profile:
        log.warning("New pipeline does not yet support memory profiling.")

-    pipeline = PipelineFactory.create_pipeline(config, method, is_update_run)
+    # todo: this could propagate out to the cli for better clarity, but will be a breaking api change
+    method = _get_method(method, is_update_run)
+    pipeline = PipelineFactory.create_pipeline(config, method)

    workflow_callbacks.pipeline_start(pipeline.names())

@@ -90,3 +92,8 @@ async def build_index(
def register_workflow_function(name: str, workflow: WorkflowFunction):
    """Register a custom workflow function. You can then include the name in the settings.yaml workflows list."""
    PipelineFactory.register(name, workflow)
+
+
+def _get_method(method: IndexingMethod | str, is_update_run: bool) -> str:
+    m = method.value if isinstance(method, IndexingMethod) else method
+    return f"{m}-update" if is_update_run else m
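As a rough sketch of how these new hooks could be used from calling code (the import path for `register_workflow_function`, the workflow name, and its body are assumptions for illustration; the async signature mirrors the built-in workflows shown later in this diff):

    from graphrag.api import register_workflow_function  # assumed re-export path
    from graphrag.config.models.graph_rag_config import GraphRagConfig
    from graphrag.index.typing.context import PipelineRunContext
    from graphrag.index.typing.workflow import WorkflowFunctionOutput


    async def my_custom_step(
        config: GraphRagConfig, context: PipelineRunContext
    ) -> WorkflowFunctionOutput:
        # hypothetical step: record a marker in the shared pipeline state
        context.state["my_custom_step_ran"] = True
        return WorkflowFunctionOutput(result=None)


    # once registered, "my_custom_step" can be listed under `workflows:` in settings.yaml
    register_workflow_function("my_custom_step", my_custom_step)

Method resolution follows `_get_method` above: a string or enum method plus `is_update_run=True` resolves to a registered pipeline name such as "standard-update".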
@@ -52,7 +52,6 @@ from graphrag.prompt_tune.types import DocSelectionType
async def generate_indexing_prompts(
    config: GraphRagConfig,
    logger: ProgressLogger,
-    root: str,
    chunk_size: PositiveInt = graphrag_config_defaults.chunks.size,
    overlap: Annotated[
        int, annotated_types.Gt(-1)
@@ -93,7 +92,6 @@ async def generate_indexing_prompts(
    # Retrieve documents
    logger.info("Chunking documents...")
    doc_list = await load_docs_in_chunks(
-        root=root,
        config=config,
        limit=limit,
        select_method=selection_method,
@@ -80,7 +80,6 @@ def index_cli(
    cli_overrides["reporting.base_dir"] = str(output_dir)
    cli_overrides["update_index_output.base_dir"] = str(output_dir)
    config = load_config(root_dir, config_filepath, cli_overrides)

    _run_index(
        config=config,
        method=method,
@@ -86,7 +86,6 @@ async def prompt_tune(

    prompts = await api.generate_indexing_prompts(
        config=graph_config,
-        root=str(root_path),
        logger=progress_logger,
        chunk_size=chunk_size,
        overlap=overlap,
@@ -14,11 +14,10 @@ from graphrag.config.enums import (
    CacheType,
    ChunkStrategyType,
    InputFileType,
-    InputType,
    ModelType,
    NounPhraseExtractorType,
-    OutputType,
    ReportingType,
+    StorageType,
)
from graphrag.index.operations.build_noun_graph.np_extractors.stop_words import (
    EN_STOP_WORDS,
@@ -234,16 +233,31 @@ class GlobalSearchDefaults:
    chat_model_id: str = DEFAULT_CHAT_MODEL_ID


+@dataclass
+class StorageDefaults:
+    """Default values for storage."""
+
+    type = StorageType.file
+    base_dir: str = DEFAULT_OUTPUT_BASE_DIR
+    connection_string: None = None
+    container_name: None = None
+    storage_account_blob_url: None = None
+    cosmosdb_account_url: None = None
+
+
+@dataclass
+class InputStorageDefaults(StorageDefaults):
+    """Default values for input storage."""
+
+    base_dir: str = "input"
+
+
@dataclass
class InputDefaults:
    """Default values for input."""

-    type = InputType.file
+    storage: InputStorageDefaults = field(default_factory=InputStorageDefaults)
    file_type = InputFileType.text
-    base_dir: str = "input"
-    connection_string: None = None
-    storage_account_blob_url: None = None
-    container_name: None = None
    encoding: str = "utf-8"
    file_pattern: str = ""
    file_filter: None = None
@@ -301,15 +315,10 @@ class LocalSearchDefaults:


@dataclass
-class OutputDefaults:
+class OutputDefaults(StorageDefaults):
    """Default values for output."""

-    type = OutputType.file
    base_dir: str = DEFAULT_OUTPUT_BASE_DIR
-    connection_string: None = None
-    container_name: None = None
-    storage_account_blob_url: None = None
-    cosmosdb_account_url: None = None


@dataclass
@@ -364,14 +373,10 @@ class UmapDefaults:


@dataclass
-class UpdateIndexOutputDefaults:
+class UpdateIndexOutputDefaults(StorageDefaults):
    """Default values for update index output."""

-    type = OutputType.file
    base_dir: str = "update_output"
-    connection_string: None = None
-    container_name: None = None
-    storage_account_blob_url: None = None


@dataclass
@@ -395,6 +400,7 @@ class GraphRagConfigDefaults:
    root_dir: str = ""
    models: dict = field(default_factory=dict)
    reporting: ReportingDefaults = field(default_factory=ReportingDefaults)
+    storage: StorageDefaults = field(default_factory=StorageDefaults)
    output: OutputDefaults = field(default_factory=OutputDefaults)
    outputs: None = None
    update_index_output: UpdateIndexOutputDefaults = field(
@@ -42,20 +42,7 @@ class InputFileType(str, Enum):
        return f'"{self.value}"'


-class InputType(str, Enum):
-    """The input type for the pipeline."""
-
-    file = "file"
-    """The file storage type."""
-    blob = "blob"
-    """The blob storage type."""
-
-    def __repr__(self):
-        """Get a string representation."""
-        return f'"{self.value}"'
-
-
-class OutputType(str, Enum):
+class StorageType(str, Enum):
    """The output type for the pipeline."""

    file = "file"
@@ -152,6 +139,10 @@ class IndexingMethod(str, Enum):
    """Traditional GraphRAG indexing, with all graph construction and summarization performed by a language model."""
    Fast = "fast"
    """Fast indexing, using NLP for graph construction and language model for summarization."""
+    StandardUpdate = "standard-update"
+    """Incremental update with standard indexing."""
+    FastUpdate = "fast-update"
+    """Incremental update with fast indexing."""


class NounPhraseExtractorType(str, Enum):
@@ -58,9 +58,11 @@ models:
### Input settings ###

input:
-  type: {graphrag_config_defaults.input.type.value} # or blob
+  storage:
+    type: {graphrag_config_defaults.input.storage.type.value} # or blob
+    base_dir: "{graphrag_config_defaults.input.storage.base_dir}"
  file_type: {graphrag_config_defaults.input.file_type.value} # [csv, text, json]
-  base_dir: "{graphrag_config_defaults.input.base_dir}"

chunks:
  size: {graphrag_config_defaults.chunks.size}
@@ -26,10 +26,10 @@ from graphrag.config.models.global_search_config import GlobalSearchConfig
from graphrag.config.models.input_config import InputConfig
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.config.models.local_search_config import LocalSearchConfig
-from graphrag.config.models.output_config import OutputConfig
from graphrag.config.models.prune_graph_config import PruneGraphConfig
from graphrag.config.models.reporting_config import ReportingConfig
from graphrag.config.models.snapshots_config import SnapshotsConfig
+from graphrag.config.models.storage_config import StorageConfig
from graphrag.config.models.summarize_descriptions_config import (
    SummarizeDescriptionsConfig,
)
@@ -102,21 +102,31 @@ class GraphRagConfig(BaseModel):
        else:
            self.input.file_pattern = f".*\\.{self.input.file_type.value}$"

+    def _validate_input_base_dir(self) -> None:
+        """Validate the input base directory."""
+        if self.input.storage.type == defs.StorageType.file:
+            if self.input.storage.base_dir.strip() == "":
+                msg = "input storage base directory is required for file input storage. Please rerun `graphrag init` and set the input storage configuration."
+                raise ValueError(msg)
+            self.input.storage.base_dir = str(
+                (Path(self.root_dir) / self.input.storage.base_dir).resolve()
+            )
+
    chunks: ChunkingConfig = Field(
        description="The chunking configuration to use.",
        default=ChunkingConfig(),
    )
    """The chunking configuration to use."""

-    output: OutputConfig = Field(
+    output: StorageConfig = Field(
        description="The output configuration.",
-        default=OutputConfig(),
+        default=StorageConfig(),
    )
    """The output configuration."""

    def _validate_output_base_dir(self) -> None:
        """Validate the output base directory."""
-        if self.output.type == defs.OutputType.file:
+        if self.output.type == defs.StorageType.file:
            if self.output.base_dir.strip() == "":
                msg = "output base directory is required for file output. Please rerun `graphrag init` and set the output configuration."
                raise ValueError(msg)
@@ -124,7 +134,7 @@ class GraphRagConfig(BaseModel):
                (Path(self.root_dir) / self.output.base_dir).resolve()
            )

-    outputs: dict[str, OutputConfig] | None = Field(
+    outputs: dict[str, StorageConfig] | None = Field(
        description="A list of output configurations used for multi-index query.",
        default=graphrag_config_defaults.outputs,
    )
@@ -133,7 +143,7 @@ class GraphRagConfig(BaseModel):
        """Validate the outputs dict base directories."""
        if self.outputs:
            for output in self.outputs.values():
-                if output.type == defs.OutputType.file:
+                if output.type == defs.StorageType.file:
                    if output.base_dir.strip() == "":
                        msg = "Output base directory is required for file output. Please rerun `graphrag init` and set the output configuration."
                        raise ValueError(msg)
@@ -141,10 +151,9 @@ class GraphRagConfig(BaseModel):
                (Path(self.root_dir) / output.base_dir).resolve()
            )

-    update_index_output: OutputConfig = Field(
+    update_index_output: StorageConfig = Field(
        description="The output configuration for the updated index.",
-        default=OutputConfig(
-            type=graphrag_config_defaults.update_index_output.type,
+        default=StorageConfig(
            base_dir=graphrag_config_defaults.update_index_output.base_dir,
        ),
    )
@@ -152,7 +161,7 @@ class GraphRagConfig(BaseModel):

    def _validate_update_index_output_base_dir(self) -> None:
        """Validate the update index output base directory."""
-        if self.update_index_output.type == defs.OutputType.file:
+        if self.update_index_output.type == defs.StorageType.file:
            if self.update_index_output.base_dir.strip() == "":
                msg = "update_index_output base directory is required for file output. Please rerun `graphrag init` and set the update_index_output configuration."
                raise ValueError(msg)
@@ -345,6 +354,7 @@ class GraphRagConfig(BaseModel):
        self._validate_root_dir()
        self._validate_models()
        self._validate_input_pattern()
+        self._validate_input_base_dir()
        self._validate_reporting_base_dir()
        self._validate_output_base_dir()
        self._validate_multi_output_base_dirs()
@@ -7,36 +7,23 @@ from pydantic import BaseModel, Field

import graphrag.config.defaults as defs
from graphrag.config.defaults import graphrag_config_defaults
-from graphrag.config.enums import InputFileType, InputType
+from graphrag.config.enums import InputFileType
+from graphrag.config.models.storage_config import StorageConfig


class InputConfig(BaseModel):
    """The default configuration section for Input."""

-    type: InputType = Field(
-        description="The input type to use.",
-        default=graphrag_config_defaults.input.type,
+    storage: StorageConfig = Field(
+        description="The storage configuration to use for reading input documents.",
+        default=StorageConfig(
+            base_dir=graphrag_config_defaults.input.storage.base_dir,
+        ),
    )
    file_type: InputFileType = Field(
        description="The input file type to use.",
        default=graphrag_config_defaults.input.file_type,
    )
-    base_dir: str = Field(
-        description="The input base directory to use.",
-        default=graphrag_config_defaults.input.base_dir,
-    )
-    connection_string: str | None = Field(
-        description="The azure blob storage connection string to use.",
-        default=graphrag_config_defaults.input.connection_string,
-    )
-    storage_account_blob_url: str | None = Field(
-        description="The storage account blob url to use.",
-        default=graphrag_config_defaults.input.storage_account_blob_url,
-    )
-    container_name: str | None = Field(
-        description="The azure blob storage container name to use.",
-        default=graphrag_config_defaults.input.container_name,
-    )
    encoding: str = Field(
        description="The input file encoding to use.",
        default=defs.graphrag_config_defaults.input.encoding,
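A short sketch of the reshaped InputConfig in use, with storage details nested rather than flattened onto the input section (the container and account URL values are illustrative):

    from graphrag.config.enums import InputFileType, StorageType
    from graphrag.config.models.input_config import InputConfig
    from graphrag.config.models.storage_config import StorageConfig

    # blob-backed input: connection details now live on the nested StorageConfig
    input_config = InputConfig(
        storage=StorageConfig(
            type=StorageType.blob,
            base_dir="input",
            container_name="my-container",  # illustrative
            storage_account_blob_url="https://myaccount.blob.core.windows.net",  # illustrative
        ),
        file_type=InputFileType.csv,
        encoding="utf-8",
    )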
@@ -1,38 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""Parameterization settings for the default configuration."""
-
-from pydantic import BaseModel, Field
-
-from graphrag.config.defaults import graphrag_config_defaults
-from graphrag.config.enums import OutputType
-
-
-class OutputConfig(BaseModel):
-    """The default configuration section for Output."""
-
-    type: OutputType = Field(
-        description="The output type to use.",
-        default=graphrag_config_defaults.output.type,
-    )
-    base_dir: str = Field(
-        description="The base directory for the output.",
-        default=graphrag_config_defaults.output.base_dir,
-    )
-    connection_string: str | None = Field(
-        description="The storage connection string to use.",
-        default=graphrag_config_defaults.output.connection_string,
-    )
-    container_name: str | None = Field(
-        description="The storage container name to use.",
-        default=graphrag_config_defaults.output.container_name,
-    )
-    storage_account_blob_url: str | None = Field(
-        description="The storage account blob url to use.",
-        default=graphrag_config_defaults.output.storage_account_blob_url,
-    )
-    cosmosdb_account_url: str | None = Field(
-        description="The cosmosdb account url to use.",
-        default=graphrag_config_defaults.output.cosmosdb_account_url,
-    )
graphrag/config/models/storage_config.py (new file, 52 lines)
@@ -0,0 +1,52 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Parameterization settings for the default configuration."""
+
+from pathlib import Path
+
+from pydantic import BaseModel, Field, field_validator
+
+from graphrag.config.defaults import graphrag_config_defaults
+from graphrag.config.enums import StorageType
+
+
+class StorageConfig(BaseModel):
+    """The default configuration section for storage."""
+
+    type: StorageType = Field(
+        description="The storage type to use.",
+        default=graphrag_config_defaults.storage.type,
+    )
+    base_dir: str = Field(
+        description="The base directory for the output.",
+        default=graphrag_config_defaults.storage.base_dir,
+    )
+
+    # Validate the base dir for multiple OS (use Path)
+    # if not using a cloud storage type.
+    @field_validator("base_dir", mode="before")
+    @classmethod
+    def validate_base_dir(cls, value, info):
+        """Ensure that base_dir is a valid filesystem path when using local storage."""
+        # info.data contains other field values, including 'type'
+        if info.data.get("type") != StorageType.file:
+            return value
+        return str(Path(value))
+
+    connection_string: str | None = Field(
+        description="The storage connection string to use.",
+        default=graphrag_config_defaults.storage.connection_string,
+    )
+    container_name: str | None = Field(
+        description="The storage container name to use.",
+        default=graphrag_config_defaults.storage.container_name,
+    )
+    storage_account_blob_url: str | None = Field(
+        description="The storage account blob url to use.",
+        default=graphrag_config_defaults.storage.storage_account_blob_url,
+    )
+    cosmosdb_account_url: str | None = Field(
+        description="The cosmosdb account url to use.",
+        default=graphrag_config_defaults.storage.cosmosdb_account_url,
+    )
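The `validate_base_dir` hook above only normalizes paths for local storage. A quick sketch of the expected behavior (paths and container name are illustrative):

    from graphrag.config.enums import StorageType
    from graphrag.config.models.storage_config import StorageConfig

    # file storage: base_dir is passed through pathlib.Path, normalizing separators for the local OS
    local = StorageConfig(type=StorageType.file, base_dir="output/artifacts")

    # blob storage: base_dir is left untouched, since it is a container-relative prefix
    remote = StorageConfig(
        type=StorageType.blob,
        base_dir="output/artifacts",
        container_name="data",  # illustrative
    )

Because `type` is declared before `base_dir`, the validator can read it from `info.data` while `base_dir` is being validated.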
@@ -22,7 +22,7 @@ async def load_csv(
    storage: PipelineStorage,
) -> pd.DataFrame:
    """Load csv inputs from a directory."""
-    log.info("Loading csv files from %s", config.base_dir)
+    log.info("Loading csv files from %s", config.storage.base_dir)

    async def load_file(path: str, group: dict | None) -> pd.DataFrame:
        if group is None:
@@ -5,20 +5,18 @@

import logging
from collections.abc import Awaitable, Callable
from pathlib import Path
from typing import cast

import pandas as pd

-from graphrag.config.enums import InputFileType, InputType
+from graphrag.config.enums import InputFileType
from graphrag.config.models.input_config import InputConfig
from graphrag.index.input.csv import load_csv
from graphrag.index.input.json import load_json
from graphrag.index.input.text import load_text
from graphrag.logger.base import ProgressLogger
from graphrag.logger.null_progress import NullProgressLogger
-from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
-from graphrag.storage.file_pipeline_storage import FilePipelineStorage
from graphrag.storage.pipeline_storage import PipelineStorage

log = logging.getLogger(__name__)
loaders: dict[str, Callable[..., Awaitable[pd.DataFrame]]] = {
@@ -30,43 +28,12 @@ loaders: dict[str, Callable[..., Awaitable[pd.DataFrame]]] = {

async def create_input(
    config: InputConfig,
+    storage: PipelineStorage,
    progress_reporter: ProgressLogger | None = None,
-    root_dir: str | None = None,
) -> pd.DataFrame:
    """Instantiate input data for a pipeline."""
-    root_dir = root_dir or ""
-    log.info("loading input from root_dir=%s", config.base_dir)
    progress_reporter = progress_reporter or NullProgressLogger()

-    match config.type:
-        case InputType.blob:
-            log.info("using blob storage input")
-            if config.container_name is None:
-                msg = "Container name required for blob storage"
-                raise ValueError(msg)
-            if (
-                config.connection_string is None
-                and config.storage_account_blob_url is None
-            ):
-                msg = "Connection string or storage account blob url required for blob storage"
-                raise ValueError(msg)
-            storage = BlobPipelineStorage(
-                connection_string=config.connection_string,
-                storage_account_blob_url=config.storage_account_blob_url,
-                container_name=config.container_name,
-                path_prefix=config.base_dir,
-            )
-        case InputType.file:
-            log.info("using file storage for input")
-            storage = FilePipelineStorage(
-                root_dir=str(Path(root_dir) / (config.base_dir or ""))
-            )
-        case _:
-            log.info("using file storage for input")
-            storage = FilePipelineStorage(
-                root_dir=str(Path(root_dir) / (config.base_dir or ""))
-            )

    if config.file_type in loaders:
        progress = progress_reporter.child(
            f"Loading Input ({config.file_type})", transient=False
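With storage construction moved out of this module, callers are expected to build the input storage themselves and hand it to `create_input`. A minimal sketch under that assumption, using the `create_storage_from_config` helper that appears later in this diff:

    from graphrag.index.input.factory import create_input
    from graphrag.utils.api import create_storage_from_config


    async def load_documents(config):
        # config.input.storage is the nested StorageConfig introduced in this change
        input_storage = create_storage_from_config(config.input.storage)
        return await create_input(config.input, input_storage, progress_reporter=None)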
@@ -22,7 +22,7 @@ async def load_json(
    storage: PipelineStorage,
) -> pd.DataFrame:
    """Load json inputs from a directory."""
-    log.info("Loading json files from %s", config.base_dir)
+    log.info("Loading json files from %s", config.storage.base_dir)

    async def load_file(path: str, group: dict | None) -> pd.DataFrame:
        if group is None:
@@ -33,7 +33,7 @@ async def load_files(
    )

    if len(files) == 0:
-        msg = f"No {config.file_type} files found in {config.base_dir}"
+        msg = f"No {config.file_type} files found in {config.storage.base_dir}"
        raise ValueError(msg)

    files_loaded = []
@@ -11,16 +11,12 @@ import traceback
from collections.abc import AsyncIterable
from dataclasses import asdict

-import pandas as pd
-
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
-from graphrag.index.input.factory import create_input
from graphrag.index.run.utils import create_run_context
from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.pipeline import Pipeline
from graphrag.index.typing.pipeline_run_result import PipelineRunResult
-from graphrag.index.update.incremental_index import get_delta_docs
from graphrag.logger.base import ProgressLogger
from graphrag.logger.progress import Progress
from graphrag.storage.pipeline_storage import PipelineStorage
@@ -40,86 +36,72 @@ async def run_pipeline(
    """Run all workflows using a simplified pipeline."""
    root_dir = config.root_dir

-    storage = create_storage_from_config(config.output)
+    input_storage = create_storage_from_config(config.input.storage)
+    output_storage = create_storage_from_config(config.output)
    cache = create_cache_from_config(config.cache, root_dir)

-    dataset = await create_input(config.input, logger, root_dir)
-
    # load existing state in case any workflows are stateful
-    state_json = await storage.get("context.json")
+    state_json = await output_storage.get("context.json")
    state = json.loads(state_json) if state_json else {}

    if is_update_run:
        logger.info("Running incremental indexing.")

-        delta_dataset = await get_delta_docs(dataset, storage)
+        update_storage = create_storage_from_config(config.update_index_output)
+        # we use this to store the new subset index, and will merge its content with the previous index
+        update_timestamp = time.strftime("%Y%m%d-%H%M%S")
+        timestamped_storage = update_storage.child(update_timestamp)
+        delta_storage = timestamped_storage.child("delta")
+        # copy the previous output to a backup folder, so we can replace it with the update
+        # we'll read from this later when we merge the old and new indexes
+        previous_storage = timestamped_storage.child("previous")
+        await _copy_previous_output(output_storage, previous_storage)

-        # warn on empty delta dataset
-        if delta_dataset.new_inputs.empty:
-            warning_msg = "Incremental indexing found no new documents, exiting."
-            logger.warning(warning_msg)
-        else:
-            update_storage = create_storage_from_config(config.update_index_output)
-            # we use this to store the new subset index, and will merge its content with the previous index
-            update_timestamp = time.strftime("%Y%m%d-%H%M%S")
-            timestamped_storage = update_storage.child(update_timestamp)
-            delta_storage = timestamped_storage.child("delta")
-            # copy the previous output to a backup folder, so we can replace it with the update
-            # we'll read from this later when we merge the old and new indexes
-            previous_storage = timestamped_storage.child("previous")
-            await _copy_previous_output(storage, previous_storage)
-            state["update_timestamp"] = update_timestamp
+        state["update_timestamp"] = update_timestamp

-            context = create_run_context(
-                storage=delta_storage, cache=cache, callbacks=callbacks, state=state
-            )
-
-            # Run the pipeline on the new documents
-            async for table in _run_pipeline(
-                pipeline=pipeline,
-                config=config,
-                dataset=delta_dataset.new_inputs,
-                logger=logger,
-                context=context,
-            ):
-                yield table
-
-            logger.success("Finished running workflows on new documents.")
+        context = create_run_context(
+            input_storage=input_storage,
+            output_storage=delta_storage,
+            previous_storage=previous_storage,
+            cache=cache,
+            callbacks=callbacks,
+            state=state,
+            progress_logger=logger,
+        )

    else:
        logger.info("Running standard indexing.")

        context = create_run_context(
-            storage=storage, cache=cache, callbacks=callbacks, state=state
+            input_storage=input_storage,
+            output_storage=output_storage,
+            cache=cache,
+            callbacks=callbacks,
+            state=state,
+            progress_logger=logger,
        )

-        async for table in _run_pipeline(
-            pipeline=pipeline,
-            config=config,
-            dataset=dataset,
-            logger=logger,
-            context=context,
-        ):
-            yield table
+    async for table in _run_pipeline(
+        pipeline=pipeline,
+        config=config,
+        logger=logger,
+        context=context,
+    ):
+        yield table


async def _run_pipeline(
    pipeline: Pipeline,
    config: GraphRagConfig,
-    dataset: pd.DataFrame,
    logger: ProgressLogger,
    context: PipelineRunContext,
) -> AsyncIterable[PipelineRunResult]:
    start_time = time.time()

-    log.info("Final # of rows loaded: %s", len(dataset))
-    context.stats.num_documents = len(dataset)
-    last_workflow = "starting documents"
+    last_workflow = "<startup>"

    try:
        await _dump_json(context)
-        await write_table_to_storage(dataset, "documents", context.storage)

        for name, workflow_function in pipeline.run():
            last_workflow = name
@@ -132,8 +114,10 @@
            yield PipelineRunResult(
                workflow=name, result=result.result, state=context.state, errors=None
            )

            context.stats.workflows[name] = {"overall": time.time() - work_time}
+            if result.stop:
+                logger.info("Halting pipeline at workflow request")
+                break

        context.stats.total_runtime = time.time() - start_time
        await _dump_json(context)
@@ -148,10 +132,10 @@

async def _dump_json(context: PipelineRunContext) -> None:
    """Dump the stats and context state to the storage."""
-    await context.storage.set(
+    await context.output_storage.set(
        "stats.json", json.dumps(asdict(context.stats), indent=4, ensure_ascii=False)
    )
-    await context.storage.set(
+    await context.output_storage.set(
        "context.json", json.dumps(context.state, indent=4, ensure_ascii=False)
    )
@@ -14,24 +14,31 @@ from graphrag.index.typing.context import PipelineRunContext
from graphrag.index.typing.state import PipelineState
from graphrag.index.typing.stats import PipelineRunStats
+from graphrag.logger.base import ProgressLogger
+from graphrag.logger.null_progress import NullProgressLogger
from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage
from graphrag.storage.pipeline_storage import PipelineStorage
-from graphrag.utils.api import create_storage_from_config


def create_run_context(
-    storage: PipelineStorage | None = None,
+    input_storage: PipelineStorage | None = None,
+    output_storage: PipelineStorage | None = None,
+    previous_storage: PipelineStorage | None = None,
    cache: PipelineCache | None = None,
    callbacks: WorkflowCallbacks | None = None,
+    progress_logger: ProgressLogger | None = None,
    stats: PipelineRunStats | None = None,
    state: PipelineState | None = None,
) -> PipelineRunContext:
    """Create the run context for the pipeline."""
    return PipelineRunContext(
-        stats=stats or PipelineRunStats(),
+        input_storage=input_storage or MemoryPipelineStorage(),
+        output_storage=output_storage or MemoryPipelineStorage(),
+        previous_storage=previous_storage or MemoryPipelineStorage(),
        cache=cache or InMemoryCache(),
-        storage=storage or MemoryPipelineStorage(),
        callbacks=callbacks or NoopWorkflowCallbacks(),
+        progress_logger=progress_logger or NullProgressLogger(),
+        stats=stats or PipelineRunStats(),
        state=state or {},
    )
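A sketch of building a context with the new split storages; any argument left off falls back to an in-memory default, which is convenient for tests (this assumes a loaded config plus existing cache, callbacks, and logger objects):

    from graphrag.index.run.utils import create_run_context
    from graphrag.utils.api import create_storage_from_config


    def build_context(config, cache, callbacks, logger):
        return create_run_context(
            input_storage=create_storage_from_config(config.input.storage),
            output_storage=create_storage_from_config(config.output),
            cache=cache,
            callbacks=callbacks,
            progress_logger=logger,
            state={},
        )

    # for a unit test, the MemoryPipelineStorage/InMemoryCache defaults are enough:
    # context = create_run_context()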
@@ -10,6 +10,7 @@ from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.index.typing.state import PipelineState
from graphrag.index.typing.stats import PipelineRunStats
+from graphrag.logger.base import ProgressLogger
from graphrag.storage.pipeline_storage import PipelineStorage


@@ -18,11 +19,17 @@ class PipelineRunContext:
    """Provides the context for the current pipeline run."""

    stats: PipelineRunStats
-    storage: PipelineStorage
+    input_storage: PipelineStorage
+    "Storage for input documents."
+    output_storage: PipelineStorage
    "Long-term storage for pipeline verbs to use. Items written here will be written to the storage provider."
+    previous_storage: PipelineStorage
+    "Storage for previous pipeline run when running in update mode."
    cache: PipelineCache
    "Cache instance for reading previous LLM responses."
    callbacks: WorkflowCallbacks
    "Callbacks to be called during the pipeline run."
+    progress_logger: ProgressLogger
+    "Progress logger for the pipeline run."
    state: PipelineState
    "Arbitrary property bag for runtime state, persistent pre-computes, or experimental features."
@@ -15,6 +15,8 @@ class PipelineRunStats:

    num_documents: int = field(default=0)
    """Number of documents."""
+    update_documents: int = field(default=0)
+    """Number of update documents."""

    input_load_time: float = field(default=0)
    """Float representing the input load time."""
@@ -17,6 +17,8 @@ class WorkflowFunctionOutput:

    result: Any | None
    """The result of the workflow function. This can be anything - we use it only for logging downstream, and expect each workflow function to write official outputs to the provided storage."""
+    stop: bool = False
+    """Flag to indicate if the workflow should stop after this function. This should only be used when continuation could cause an unstable failure."""


WorkflowFunction = Callable[
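The new `stop` flag is what lets a workflow short-circuit the pipeline (load_update_documents later in this diff uses it when no delta documents are found). A hedged sketch of a custom workflow doing the same; the helper called here is hypothetical:

    from graphrag.index.typing.workflow import WorkflowFunctionOutput


    async def run_workflow(config, context):
        rows = await count_pending_rows(config, context)  # hypothetical helper
        if rows == 0:
            # nothing for downstream workflows to do, so ask the pipeline to halt cleanly
            return WorkflowFunctionOutput(result=None, stop=True)
        return WorkflowFunctionOutput(result=rows)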
@@ -39,6 +39,12 @@ from .finalize_graph import (
from .generate_text_embeddings import (
    run_workflow as run_generate_text_embeddings,
)
+from .load_input_documents import (
+    run_workflow as run_load_input_documents,
+)
+from .load_update_documents import (
+    run_workflow as run_load_update_documents,
+)
from .prune_graph import (
    run_workflow as run_prune_graph,
)
@@ -69,6 +75,8 @@ from .update_text_units import (

# register all of our built-in workflows at once
PipelineFactory.register_all({
+    "load_input_documents": run_load_input_documents,
+    "load_update_documents": run_load_update_documents,
    "create_base_text_units": run_create_base_text_units,
    "create_communities": run_create_communities,
    "create_community_reports_text": run_create_community_reports_text,
@@ -25,7 +25,7 @@ async def run_workflow(
    context: PipelineRunContext,
) -> WorkflowFunctionOutput:
    """All the steps to transform base text_units."""
-    documents = await load_table_from_storage("documents", context.storage)
+    documents = await load_table_from_storage("documents", context.output_storage)

    chunks = config.chunks

@@ -41,7 +41,7 @@ async def run_workflow(
        chunk_size_includes_metadata=chunks.chunk_size_includes_metadata,
    )

-    await write_table_to_storage(output, "text_units", context.storage)
+    await write_table_to_storage(output, "text_units", context.output_storage)

    return WorkflowFunctionOutput(result=output)
@@ -24,8 +24,10 @@ async def run_workflow(
    context: PipelineRunContext,
) -> WorkflowFunctionOutput:
    """All the steps to transform final communities."""
-    entities = await load_table_from_storage("entities", context.storage)
-    relationships = await load_table_from_storage("relationships", context.storage)
+    entities = await load_table_from_storage("entities", context.output_storage)
+    relationships = await load_table_from_storage(
+        "relationships", context.output_storage
+    )

    max_cluster_size = config.cluster_graph.max_cluster_size
    use_lcc = config.cluster_graph.use_lcc
@@ -39,7 +41,7 @@ async def run_workflow(
        seed=seed,
    )

-    await write_table_to_storage(output, "communities", context.storage)
+    await write_table_to_storage(output, "communities", context.output_storage)

    return WorkflowFunctionOutput(result=output)
@@ -38,14 +38,14 @@ async def run_workflow(
    context: PipelineRunContext,
) -> WorkflowFunctionOutput:
    """All the steps to transform community reports."""
-    edges = await load_table_from_storage("relationships", context.storage)
-    entities = await load_table_from_storage("entities", context.storage)
-    communities = await load_table_from_storage("communities", context.storage)
+    edges = await load_table_from_storage("relationships", context.output_storage)
+    entities = await load_table_from_storage("entities", context.output_storage)
+    communities = await load_table_from_storage("communities", context.output_storage)
    claims = None
    if config.extract_claims.enabled and await storage_has_table(
-        "covariates", context.storage
+        "covariates", context.output_storage
    ):
-        claims = await load_table_from_storage("covariates", context.storage)
+        claims = await load_table_from_storage("covariates", context.output_storage)

    community_reports_llm_settings = config.get_language_model_config(
        config.community_reports.model_id
@@ -68,7 +68,7 @@ async def run_workflow(
        num_threads=num_threads,
    )

-    await write_table_to_storage(output, "community_reports", context.storage)
+    await write_table_to_storage(output, "community_reports", context.output_storage)

    return WorkflowFunctionOutput(result=output)
@@ -37,10 +37,10 @@ async def run_workflow(
    context: PipelineRunContext,
) -> WorkflowFunctionOutput:
    """All the steps to transform community reports."""
-    entities = await load_table_from_storage("entities", context.storage)
-    communities = await load_table_from_storage("communities", context.storage)
+    entities = await load_table_from_storage("entities", context.output_storage)
+    communities = await load_table_from_storage("communities", context.output_storage)

-    text_units = await load_table_from_storage("text_units", context.storage)
+    text_units = await load_table_from_storage("text_units", context.output_storage)

    community_reports_llm_settings = config.get_language_model_config(
        config.community_reports.model_id
@@ -62,7 +62,7 @@ async def run_workflow(
        num_threads=num_threads,
    )

-    await write_table_to_storage(output, "community_reports", context.storage)
+    await write_table_to_storage(output, "community_reports", context.output_storage)

    return WorkflowFunctionOutput(result=output)
@@ -17,12 +17,12 @@ async def run_workflow(
    context: PipelineRunContext,
) -> WorkflowFunctionOutput:
    """All the steps to transform final documents."""
-    documents = await load_table_from_storage("documents", context.storage)
-    text_units = await load_table_from_storage("text_units", context.storage)
+    documents = await load_table_from_storage("documents", context.output_storage)
+    text_units = await load_table_from_storage("text_units", context.output_storage)

    output = create_final_documents(documents, text_units)

-    await write_table_to_storage(output, "documents", context.storage)
+    await write_table_to_storage(output, "documents", context.output_storage)

    return WorkflowFunctionOutput(result=output)
@@ -21,16 +21,18 @@ async def run_workflow(
    context: PipelineRunContext,
) -> WorkflowFunctionOutput:
    """All the steps to transform the text units."""
-    text_units = await load_table_from_storage("text_units", context.storage)
-    final_entities = await load_table_from_storage("entities", context.storage)
+    text_units = await load_table_from_storage("text_units", context.output_storage)
+    final_entities = await load_table_from_storage("entities", context.output_storage)
    final_relationships = await load_table_from_storage(
-        "relationships", context.storage
+        "relationships", context.output_storage
    )
    final_covariates = None
    if config.extract_claims.enabled and await storage_has_table(
-        "covariates", context.storage
+        "covariates", context.output_storage
    ):
-        final_covariates = await load_table_from_storage("covariates", context.storage)
+        final_covariates = await load_table_from_storage(
+            "covariates", context.output_storage
+        )

    output = create_final_text_units(
        text_units,
@@ -39,7 +41,7 @@ async def run_workflow(
        final_covariates,
    )

-    await write_table_to_storage(output, "text_units", context.storage)
+    await write_table_to_storage(output, "text_units", context.output_storage)

    return WorkflowFunctionOutput(result=output)
@@ -26,30 +26,32 @@ async def run_workflow(
    context: PipelineRunContext,
) -> WorkflowFunctionOutput:
    """All the steps to extract and format covariates."""
-    text_units = await load_table_from_storage("text_units", context.storage)
+    output = None
+    if config.extract_claims.enabled:
+        text_units = await load_table_from_storage("text_units", context.output_storage)

-    extract_claims_llm_settings = config.get_language_model_config(
-        config.extract_claims.model_id
-    )
-    extraction_strategy = config.extract_claims.resolved_strategy(
-        config.root_dir, extract_claims_llm_settings
-    )
+        extract_claims_llm_settings = config.get_language_model_config(
+            config.extract_claims.model_id
+        )
+        extraction_strategy = config.extract_claims.resolved_strategy(
+            config.root_dir, extract_claims_llm_settings
+        )

-    async_mode = extract_claims_llm_settings.async_mode
-    num_threads = extract_claims_llm_settings.concurrent_requests
+        async_mode = extract_claims_llm_settings.async_mode
+        num_threads = extract_claims_llm_settings.concurrent_requests

-    output = await extract_covariates(
-        text_units,
-        context.callbacks,
-        context.cache,
-        "claim",
-        extraction_strategy,
-        async_mode=async_mode,
-        entity_types=None,
-        num_threads=num_threads,
-    )
+        output = await extract_covariates(
+            text_units,
+            context.callbacks,
+            context.cache,
+            "claim",
+            extraction_strategy,
+            async_mode=async_mode,
+            entity_types=None,
+            num_threads=num_threads,
+        )

-    await write_table_to_storage(output, "covariates", context.storage)
+    await write_table_to_storage(output, "covariates", context.output_storage)

    return WorkflowFunctionOutput(result=output)
@@ -27,7 +27,7 @@ async def run_workflow(
    context: PipelineRunContext,
) -> WorkflowFunctionOutput:
    """All the steps to create the base entity graph."""
-    text_units = await load_table_from_storage("text_units", context.storage)
+    text_units = await load_table_from_storage("text_units", context.output_storage)

    extract_graph_llm_settings = config.get_language_model_config(
        config.extract_graph.model_id
@@ -55,13 +55,15 @@ async def run_workflow(
        summarization_num_threads=summarization_llm_settings.concurrent_requests,
    )

-    await write_table_to_storage(entities, "entities", context.storage)
-    await write_table_to_storage(relationships, "relationships", context.storage)
+    await write_table_to_storage(entities, "entities", context.output_storage)
+    await write_table_to_storage(relationships, "relationships", context.output_storage)

    if config.snapshots.raw_graph:
-        await write_table_to_storage(raw_entities, "raw_entities", context.storage)
        await write_table_to_storage(
-            raw_relationships, "raw_relationships", context.storage
+            raw_entities, "raw_entities", context.output_storage
        )
+        await write_table_to_storage(
+            raw_relationships, "raw_relationships", context.output_storage
+        )

    return WorkflowFunctionOutput(
@@ -22,7 +22,7 @@ async def run_workflow(
    context: PipelineRunContext,
) -> WorkflowFunctionOutput:
    """All the steps to create the base entity graph."""
-    text_units = await load_table_from_storage("text_units", context.storage)
+    text_units = await load_table_from_storage("text_units", context.output_storage)

    entities, relationships = await extract_graph_nlp(
        text_units,
@@ -30,8 +30,8 @@ async def run_workflow(
        extraction_config=config.extract_graph_nlp,
    )

-    await write_table_to_storage(entities, "entities", context.storage)
-    await write_table_to_storage(relationships, "relationships", context.storage)
+    await write_table_to_storage(entities, "entities", context.output_storage)
+    await write_table_to_storage(relationships, "relationships", context.output_storage)

    return WorkflowFunctionOutput(
        result={
@@ -15,6 +15,7 @@ class PipelineFactory:
    """A factory class for workflow pipelines."""

    workflows: ClassVar[dict[str, WorkflowFunction]] = {}
+    pipelines: ClassVar[dict[str, list[str]]] = {}

    @classmethod
    def register(cls, name: str, workflow: WorkflowFunction):
@@ -27,61 +28,66 @@ class PipelineFactory:
        for name, workflow in workflows.items():
            cls.register(name, workflow)

+    @classmethod
+    def register_pipeline(cls, name: str, workflows: list[str]):
+        """Register a new pipeline method as a list of workflow names."""
+        cls.pipelines[name] = workflows
+
    @classmethod
    def create_pipeline(
        cls,
        config: GraphRagConfig,
-        method: IndexingMethod = IndexingMethod.Standard,
-        is_update_run: bool = False,
+        method: IndexingMethod | str = IndexingMethod.Standard,
    ) -> Pipeline:
        """Create a pipeline generator."""
-        workflows = _get_workflows_list(config, method, is_update_run)
+        workflows = config.workflows or cls.pipelines.get(method, [])
        return Pipeline([(name, cls.workflows[name]) for name in workflows])


-def _get_workflows_list(
-    config: GraphRagConfig,
-    method: IndexingMethod = IndexingMethod.Standard,
-    is_update_run: bool = False,
-) -> list[str]:
-    """Return a list of workflows for the indexing pipeline."""
-    update_workflows = [
-        "update_final_documents",
-        "update_entities_relationships",
-        "update_text_units",
-        "update_covariates",
-        "update_communities",
-        "update_community_reports",
-        "update_text_embeddings",
-        "update_clean_state",
-    ]
-    if config.workflows:
-        return config.workflows
-
-    match method:
-        case IndexingMethod.Standard:
-            return [
-                "create_base_text_units",
-                "create_final_documents",
-                "extract_graph",
-                "finalize_graph",
-                *(["extract_covariates"] if config.extract_claims.enabled else []),
-                "create_communities",
-                "create_final_text_units",
-                "create_community_reports",
-                "generate_text_embeddings",
-                *(update_workflows if is_update_run else []),
-            ]
-        case IndexingMethod.Fast:
-            return [
-                "create_base_text_units",
-                "create_final_documents",
-                "extract_graph_nlp",
-                "prune_graph",
-                "finalize_graph",
-                "create_communities",
-                "create_final_text_units",
-                "create_community_reports_text",
-                "generate_text_embeddings",
-                *(update_workflows if is_update_run else []),
-            ]
+# --- Register default implementations ---
+_standard_workflows = [
+    "create_base_text_units",
+    "create_final_documents",
+    "extract_graph",
+    "finalize_graph",
+    "extract_covariates",
+    "create_communities",
+    "create_final_text_units",
+    "create_community_reports",
+    "generate_text_embeddings",
+]
+_fast_workflows = [
+    "create_base_text_units",
+    "create_final_documents",
+    "extract_graph_nlp",
+    "prune_graph",
+    "finalize_graph",
+    "create_communities",
+    "create_final_text_units",
+    "create_community_reports_text",
+    "generate_text_embeddings",
+]
+_update_workflows = [
+    "update_final_documents",
+    "update_entities_relationships",
+    "update_text_units",
+    "update_covariates",
+    "update_communities",
+    "update_community_reports",
+    "update_text_embeddings",
+    "update_clean_state",
+]
+PipelineFactory.register_pipeline(
+    IndexingMethod.Standard, ["load_input_documents", *_standard_workflows]
+)
+PipelineFactory.register_pipeline(
+    IndexingMethod.Fast, ["load_input_documents", *_fast_workflows]
+)
+PipelineFactory.register_pipeline(
+    IndexingMethod.StandardUpdate,
+    ["load_update_documents", *_standard_workflows, *_update_workflows],
+)
+PipelineFactory.register_pipeline(
+    IndexingMethod.FastUpdate,
+    ["load_update_documents", *_fast_workflows, *_update_workflows],
+)
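Since `create_pipeline` now just looks the method name up in `cls.pipelines`, a custom end-to-end pipeline can be registered the same way the defaults are above. A sketch under assumptions (the "lean" method name, the workflow mix, and the PipelineFactory import path are illustrative):

    from graphrag.index.workflows.factory import PipelineFactory  # assumed module path

    # reuse built-in workflow names, skipping covariates and community reports
    PipelineFactory.register_pipeline(
        "lean",
        [
            "load_input_documents",
            "create_base_text_units",
            "create_final_documents",
            "extract_graph",
            "finalize_graph",
            "create_communities",
            "create_final_text_units",
            "generate_text_embeddings",
        ],
    )

    # then, since build_index now accepts a plain string method:
    # await build_index(config, method="lean")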
@@ -22,8 +22,10 @@ async def run_workflow(
    context: PipelineRunContext,
) -> WorkflowFunctionOutput:
    """All the steps to create the base entity graph."""
-    entities = await load_table_from_storage("entities", context.storage)
-    relationships = await load_table_from_storage("relationships", context.storage)
+    entities = await load_table_from_storage("entities", context.output_storage)
+    relationships = await load_table_from_storage(
+        "relationships", context.output_storage
+    )

    final_entities, final_relationships = finalize_graph(
        entities,
@@ -33,8 +35,10 @@ async def run_workflow(
        layout_enabled=config.umap.enabled,
    )

-    await write_table_to_storage(final_entities, "entities", context.storage)
-    await write_table_to_storage(final_relationships, "relationships", context.storage)
+    await write_table_to_storage(final_entities, "entities", context.output_storage)
+    await write_table_to_storage(
+        final_relationships, "relationships", context.output_storage
+    )

    if config.snapshots.graphml:
        # todo: extract graphs at each level, and add in meta like descriptions
@@ -43,7 +47,7 @@ async def run_workflow(
        await snapshot_graphml(
            graph,
            name="graph",
-            storage=context.storage,
+            storage=context.output_storage,
        )

    return WorkflowFunctionOutput(
@@ -43,17 +43,19 @@ async def run_workflow(
    text_units = None
    entities = None
    community_reports = None
-    if await storage_has_table("documents", context.storage):
-        documents = await load_table_from_storage("documents", context.storage)
-    if await storage_has_table("relationships", context.storage):
-        relationships = await load_table_from_storage("relationships", context.storage)
-    if await storage_has_table("text_units", context.storage):
-        text_units = await load_table_from_storage("text_units", context.storage)
-    if await storage_has_table("entities", context.storage):
-        entities = await load_table_from_storage("entities", context.storage)
-    if await storage_has_table("community_reports", context.storage):
+    if await storage_has_table("documents", context.output_storage):
+        documents = await load_table_from_storage("documents", context.output_storage)
+    if await storage_has_table("relationships", context.output_storage):
+        relationships = await load_table_from_storage(
+            "relationships", context.output_storage
+        )
+    if await storage_has_table("text_units", context.output_storage):
+        text_units = await load_table_from_storage("text_units", context.output_storage)
+    if await storage_has_table("entities", context.output_storage):
+        entities = await load_table_from_storage("entities", context.output_storage)
+    if await storage_has_table("community_reports", context.output_storage):
        community_reports = await load_table_from_storage(
-            "community_reports", context.storage
+            "community_reports", context.output_storage
        )

    embedded_fields = config.embed_text.names
@@ -76,7 +78,7 @@ async def run_workflow(
        await write_table_to_storage(
            table,
            f"embeddings.{name}",
-            context.storage,
+            context.output_storage,
        )

    return WorkflowFunctionOutput(result=output)
graphrag/index/workflows/load_input_documents.py (new file, 45 lines)
@@ -0,0 +1,45 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing run_workflow method definition."""
+
+import logging
+
+import pandas as pd
+
+from graphrag.config.models.graph_rag_config import GraphRagConfig
+from graphrag.config.models.input_config import InputConfig
+from graphrag.index.input.factory import create_input
+from graphrag.index.typing.context import PipelineRunContext
+from graphrag.index.typing.workflow import WorkflowFunctionOutput
+from graphrag.logger.base import ProgressLogger
+from graphrag.storage.pipeline_storage import PipelineStorage
+from graphrag.utils.storage import write_table_to_storage
+
+log = logging.getLogger(__name__)
+
+
+async def run_workflow(
+    config: GraphRagConfig,
+    context: PipelineRunContext,
+) -> WorkflowFunctionOutput:
+    """Load and parse input documents into a standard format."""
+    output = await load_input_documents(
+        config.input,
+        context.input_storage,
+        context.progress_logger,
+    )
+
+    log.info("Final # of rows loaded: %s", len(output))
+    context.stats.num_documents = len(output)
+
+    await write_table_to_storage(output, "documents", context.output_storage)
+
+    return WorkflowFunctionOutput(result=output)
+
+
+async def load_input_documents(
+    config: InputConfig, storage: PipelineStorage, progress_logger: ProgressLogger
+) -> pd.DataFrame:
+    """Load and parse input documents into a standard format."""
+    return await create_input(config, storage, progress_logger)
graphrag/index/workflows/load_update_documents.py (new file, 59 lines)
@@ -0,0 +1,59 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing run_workflow method definition."""
+
+import logging
+
+import pandas as pd
+
+from graphrag.config.models.graph_rag_config import GraphRagConfig
+from graphrag.config.models.input_config import InputConfig
+from graphrag.index.input.factory import create_input
+from graphrag.index.typing.context import PipelineRunContext
+from graphrag.index.typing.workflow import WorkflowFunctionOutput
+from graphrag.index.update.incremental_index import get_delta_docs
+from graphrag.logger.base import ProgressLogger
+from graphrag.storage.pipeline_storage import PipelineStorage
+from graphrag.utils.storage import write_table_to_storage
+
+log = logging.getLogger(__name__)
+
+
+async def run_workflow(
+    config: GraphRagConfig,
+    context: PipelineRunContext,
+) -> WorkflowFunctionOutput:
+    """Load and parse update-only input documents into a standard format."""
+    output = await load_update_documents(
+        config.input,
+        context.input_storage,
+        context.previous_storage,
+        context.progress_logger,
+    )
+
+    log.info("Final # of update rows loaded: %s", len(output))
+    context.stats.update_documents = len(output)
+
+    if len(output) == 0:
+        log.warning("No new update documents found.")
+        context.progress_logger.warning("No new update documents found.")
+        return WorkflowFunctionOutput(result=None, stop=True)
+
+    await write_table_to_storage(output, "documents", context.output_storage)
+
+    return WorkflowFunctionOutput(result=output)
+
+
+async def load_update_documents(
+    config: InputConfig,
+    input_storage: PipelineStorage,
+    previous_storage: PipelineStorage,
+    progress_logger: ProgressLogger,
+) -> pd.DataFrame:
+    """Load and parse update-only input documents into a standard format."""
+    input_documents = await create_input(config, input_storage, progress_logger)
+    # previous storage is the output of the previous run
+    # we'll use this to diff the input from the prior
+    delta_documents = await get_delta_docs(input_documents, previous_storage)
+    return delta_documents.new_inputs
@@ -20,8 +20,10 @@ async def run_workflow(
    context: PipelineRunContext,
) -> WorkflowFunctionOutput:
    """All the steps to create the base entity graph."""
-    entities = await load_table_from_storage("entities", context.storage)
-    relationships = await load_table_from_storage("relationships", context.storage)
+    entities = await load_table_from_storage("entities", context.output_storage)
+    relationships = await load_table_from_storage(
+        "relationships", context.output_storage
+    )

    pruned_entities, pruned_relationships = prune_graph(
        entities,
@@ -29,8 +31,10 @@ async def run_workflow(
        pruning_config=config.prune_graph,
    )

-    await write_table_to_storage(pruned_entities, "entities", context.storage)
-    await write_table_to_storage(pruned_relationships, "relationships", context.storage)
+    await write_table_to_storage(pruned_entities, "entities", context.output_storage)
+    await write_table_to_storage(
+        pruned_relationships, "relationships", context.output_storage
+    )

    return WorkflowFunctionOutput(
        result={
@ -21,6 +21,7 @@ from graphrag.prompt_tune.defaults import (
    K,
)
from graphrag.prompt_tune.types import DocSelectionType
from graphrag.utils.api import create_storage_from_config


def _sample_chunks_from_embeddings(
@ -37,7 +38,6 @@ def _sample_chunks_from_embeddings(


async def load_docs_in_chunks(
    root: str,
    config: GraphRagConfig,
    select_method: DocSelectionType,
    limit: int,
@ -51,7 +51,8 @@ async def load_docs_in_chunks(
    embeddings_llm_settings = config.get_language_model_config(
        config.embed_text.model_id
    )
    dataset = await create_input(config.input, logger, root)
    input_storage = create_storage_from_config(config.input.storage)
    dataset = await create_input(config.input, input_storage, logger)
    chunk_config = config.chunks
    chunks_df = create_base_text_units(
        documents=dataset,
@ -7,7 +7,7 @@ from __future__ import annotations

from typing import TYPE_CHECKING, ClassVar

from graphrag.config.enums import OutputType
from graphrag.config.enums import StorageType
from graphrag.storage.blob_pipeline_storage import create_blob_storage
from graphrag.storage.cosmosdb_pipeline_storage import create_cosmosdb_storage
from graphrag.storage.file_pipeline_storage import create_file_storage
@ -35,17 +35,17 @@ class StorageFactory:

    @classmethod
    def create_storage(
        cls, storage_type: OutputType | str, kwargs: dict
        cls, storage_type: StorageType | str, kwargs: dict
    ) -> PipelineStorage:
        """Create or get a storage object from the provided type."""
        match storage_type:
            case OutputType.blob:
            case StorageType.blob:
                return create_blob_storage(**kwargs)
            case OutputType.cosmosdb:
            case StorageType.cosmosdb:
                return create_cosmosdb_storage(**kwargs)
            case OutputType.file:
            case StorageType.file:
                return create_file_storage(**kwargs)
            case OutputType.memory:
            case StorageType.memory:
                return MemoryPipelineStorage()
            case _:
                if storage_type in cls.storage_types:
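For reference, the factory is now keyed on StorageType, and strings that are not built-in values fall through to the class-level storage_types registry shown in the last branch. A minimal usage sketch, mirroring the unit tests further down (the base_dir value is a placeholder):

from graphrag.config.enums import StorageType
from graphrag.storage.factory import StorageFactory

# Built-in backends are selected via the StorageType enum; a plain string also
# works because create_storage accepts StorageType | str.
file_storage = StorageFactory.create_storage(
    StorageType.file, {"type": "file", "base_dir": "./output"}  # placeholder path
)
# Unknown strings fall through to cls.storage_types, which is where custom
# storage implementations are looked up.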
@ -10,7 +10,7 @@ from graphrag.cache.factory import CacheFactory
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.config.embeddings import create_collection_name
from graphrag.config.models.cache_config import CacheConfig
from graphrag.config.models.output_config import OutputConfig
from graphrag.config.models.storage_config import StorageConfig
from graphrag.data_model.types import TextEmbedder
from graphrag.storage.factory import StorageFactory
from graphrag.storage.pipeline_storage import PipelineStorage
@ -238,7 +238,7 @@ def load_search_prompt(root_dir: str, prompt_config: str | None) -> str | None:
    return None


def create_storage_from_config(output: OutputConfig) -> PipelineStorage:
def create_storage_from_config(output: StorageConfig) -> PipelineStorage:
    """Create a storage object from the config."""
    storage_config = output.model_dump()
    return StorageFactory().create_storage(
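Taken together with the input changes above, document loading now follows the same two-step pattern everywhere: build a PipelineStorage from the nested StorageConfig, then hand it to create_input. A sketch mirroring the unit tests below (the directory and file pattern are placeholders, and the call must run inside an async context):

from graphrag.config.enums import InputFileType
from graphrag.config.models.input_config import InputConfig
from graphrag.config.models.storage_config import StorageConfig
from graphrag.index.input.factory import create_input
from graphrag.utils.api import create_storage_from_config


async def load_documents():
    config = InputConfig(
        storage=StorageConfig(base_dir="input"),  # placeholder directory
        file_type=InputFileType.text,
        file_pattern=".*\\.txt$",
    )
    storage = create_storage_from_config(config.storage)
    # create_input reads from the storage object instead of a root path.
    return await create_input(config=config, storage=storage)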
12 tests/fixtures/azure/settings.yml (vendored)
@ -9,11 +9,13 @@ vector_store:
  container_name: "azure_ci"

input:
  type: blob
  storage:
    type: blob
    connection_string: ${LOCAL_BLOB_STORAGE_CONNECTION_STRING}
    container_name: azurefixture
    base_dir: input
  file_type: text
  connection_string: ${LOCAL_BLOB_STORAGE_CONNECTION_STRING}
  container_name: azurefixture
  base_dir: input

cache:
  type: blob
@ -21,7 +23,7 @@ cache:
  container_name: cicache
  base_dir: cache_azure_ai

storage:
output:
  type: blob
  connection_string: ${LOCAL_BLOB_STORAGE_CONNECTION_STRING}
  container_name: azurefixture
6 tests/fixtures/min-csv/config.json (vendored)
@ -2,9 +2,15 @@
  "input_path": "./tests/fixtures/min-csv",
  "input_file_type": "text",
  "workflow_config": {
    "load_input_documents": {
      "max_runtime": 30
    },
    "create_base_text_units": {
      "max_runtime": 30
    },
    "extract_covariates": {
      "max_runtime": 10
    },
    "extract_graph": {
      "max_runtime": 500
    },
3 tests/fixtures/text/config.json (vendored)
@ -2,6 +2,9 @@
  "input_path": "./tests/fixtures/text",
  "input_file_type": "text",
  "workflow_config": {
    "load_input_documents": {
      "max_runtime": 30
    },
    "create_base_text_units": {
      "max_runtime": 30
    },
@ -9,7 +9,7 @@ import sys

import pytest

from graphrag.config.enums import OutputType
from graphrag.config.enums import StorageType
from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage
from graphrag.storage.factory import StorageFactory
@ -29,7 +29,7 @@ def test_create_blob_storage():
        "base_dir": "testbasedir",
        "container_name": "testcontainer",
    }
    storage = StorageFactory.create_storage(OutputType.blob, kwargs)
    storage = StorageFactory.create_storage(StorageType.blob, kwargs)
    assert isinstance(storage, BlobPipelineStorage)


@ -44,19 +44,19 @@ def test_create_cosmosdb_storage():
        "base_dir": "testdatabase",
        "container_name": "testcontainer",
    }
    storage = StorageFactory.create_storage(OutputType.cosmosdb, kwargs)
    storage = StorageFactory.create_storage(StorageType.cosmosdb, kwargs)
    assert isinstance(storage, CosmosDBPipelineStorage)


def test_create_file_storage():
    kwargs = {"type": "file", "base_dir": "/tmp/teststorage"}
    storage = StorageFactory.create_storage(OutputType.file, kwargs)
    storage = StorageFactory.create_storage(StorageType.file, kwargs)
    assert isinstance(storage, FilePipelineStorage)


def test_create_memory_storage():
    kwargs = {"type": "memory"}
    storage = StorageFactory.create_storage(OutputType.memory, kwargs)
    storage = StorageFactory.create_storage(StorageType.memory, kwargs)
    assert isinstance(storage, MemoryPipelineStorage)
@ -24,10 +24,10 @@ from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.config.models.input_config import InputConfig
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.config.models.local_search_config import LocalSearchConfig
from graphrag.config.models.output_config import OutputConfig
from graphrag.config.models.prune_graph_config import PruneGraphConfig
from graphrag.config.models.reporting_config import ReportingConfig
from graphrag.config.models.snapshots_config import SnapshotsConfig
from graphrag.config.models.storage_config import StorageConfig
from graphrag.config.models.summarize_descriptions_config import (
    SummarizeDescriptionsConfig,
)
@ -134,7 +134,7 @@ def assert_reporting_configs(
    assert actual.storage_account_blob_url == expected.storage_account_blob_url


def assert_output_configs(actual: OutputConfig, expected: OutputConfig) -> None:
def assert_output_configs(actual: StorageConfig, expected: StorageConfig) -> None:
    assert expected.type == actual.type
    assert expected.base_dir == actual.base_dir
    assert expected.connection_string == actual.connection_string
@ -143,7 +143,9 @@ def assert_output_configs(actual: OutputConfig, expected: OutputConfig) -> None:
    assert expected.cosmosdb_account_url == actual.cosmosdb_account_url


def assert_update_output_configs(actual: OutputConfig, expected: OutputConfig) -> None:
def assert_update_output_configs(
    actual: StorageConfig, expected: StorageConfig
) -> None:
    assert expected.type == actual.type
    assert expected.base_dir == actual.base_dir
    assert expected.connection_string == actual.connection_string
@ -162,12 +164,15 @@ def assert_cache_configs(actual: CacheConfig, expected: CacheConfig) -> None:


def assert_input_configs(actual: InputConfig, expected: InputConfig) -> None:
    assert actual.type == expected.type
    assert actual.storage.type == expected.storage.type
    assert actual.file_type == expected.file_type
    assert actual.base_dir == expected.base_dir
    assert actual.connection_string == expected.connection_string
    assert actual.storage_account_blob_url == expected.storage_account_blob_url
    assert actual.container_name == expected.container_name
    assert actual.storage.base_dir == expected.storage.base_dir
    assert actual.storage.connection_string == expected.storage.connection_string
    assert (
        actual.storage.storage_account_blob_url
        == expected.storage.storage_account_blob_url
    )
    assert actual.storage.container_name == expected.storage.container_name
    assert actual.encoding == expected.encoding
    assert actual.file_pattern == expected.file_pattern
    assert actual.file_filter == expected.file_filter
@ -1,56 +1,66 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

from graphrag.config.enums import InputFileType, InputType
from graphrag.config.enums import InputFileType
from graphrag.config.models.input_config import InputConfig
from graphrag.config.models.storage_config import StorageConfig
from graphrag.index.input.factory import create_input
from graphrag.utils.api import create_storage_from_config


async def test_csv_loader_one_file():
    config = InputConfig(
        type=InputType.file,
        storage=StorageConfig(
            base_dir="tests/unit/indexing/input/data/one-csv",
        ),
        file_type=InputFileType.csv,
        file_pattern=".*\\.csv$",
        base_dir="tests/unit/indexing/input/data/one-csv",
    )
    documents = await create_input(config=config)
    storage = create_storage_from_config(config.storage)
    documents = await create_input(config=config, storage=storage)
    assert documents.shape == (2, 4)
    assert documents["title"].iloc[0] == "input.csv"


async def test_csv_loader_one_file_with_title():
    config = InputConfig(
        type=InputType.file,
        storage=StorageConfig(
            base_dir="tests/unit/indexing/input/data/one-csv",
        ),
        file_type=InputFileType.csv,
        file_pattern=".*\\.csv$",
        base_dir="tests/unit/indexing/input/data/one-csv",
        title_column="title",
    )
    documents = await create_input(config=config)
    storage = create_storage_from_config(config.storage)
    documents = await create_input(config=config, storage=storage)
    assert documents.shape == (2, 4)
    assert documents["title"].iloc[0] == "Hello"


async def test_csv_loader_one_file_with_metadata():
    config = InputConfig(
        type=InputType.file,
        storage=StorageConfig(
            base_dir="tests/unit/indexing/input/data/one-csv",
        ),
        file_type=InputFileType.csv,
        file_pattern=".*\\.csv$",
        base_dir="tests/unit/indexing/input/data/one-csv",
        title_column="title",
        metadata=["title"],
    )
    documents = await create_input(config=config)
    storage = create_storage_from_config(config.storage)
    documents = await create_input(config=config, storage=storage)
    assert documents.shape == (2, 5)
    assert documents["metadata"][0] == {"title": "Hello"}


async def test_csv_loader_multiple_files():
    config = InputConfig(
        type=InputType.file,
        storage=StorageConfig(
            base_dir="tests/unit/indexing/input/data/multiple-csvs",
        ),
        file_type=InputFileType.csv,
        file_pattern=".*\\.csv$",
        base_dir="tests/unit/indexing/input/data/multiple-csvs",
    )
    documents = await create_input(config=config)
    storage = create_storage_from_config(config.storage)
    documents = await create_input(config=config, storage=storage)
    assert documents.shape == (4, 4)
@ -1,31 +1,37 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

from graphrag.config.enums import InputFileType, InputType
from graphrag.config.enums import InputFileType
from graphrag.config.models.input_config import InputConfig
from graphrag.config.models.storage_config import StorageConfig
from graphrag.index.input.factory import create_input
from graphrag.utils.api import create_storage_from_config


async def test_json_loader_one_file_one_object():
    config = InputConfig(
        type=InputType.file,
        storage=StorageConfig(
            base_dir="tests/unit/indexing/input/data/one-json-one-object",
        ),
        file_type=InputFileType.json,
        file_pattern=".*\\.json$",
        base_dir="tests/unit/indexing/input/data/one-json-one-object",
    )
    documents = await create_input(config=config)
    storage = create_storage_from_config(config.storage)
    documents = await create_input(config=config, storage=storage)
    assert documents.shape == (1, 4)
    assert documents["title"].iloc[0] == "input.json"


async def test_json_loader_one_file_multiple_objects():
    config = InputConfig(
        type=InputType.file,
        storage=StorageConfig(
            base_dir="tests/unit/indexing/input/data/one-json-multiple-objects",
        ),
        file_type=InputFileType.json,
        file_pattern=".*\\.json$",
        base_dir="tests/unit/indexing/input/data/one-json-multiple-objects",
    )
    documents = await create_input(config=config)
    storage = create_storage_from_config(config.storage)
    documents = await create_input(config=config, storage=storage)
    print(documents)
    assert documents.shape == (3, 4)
    assert documents["title"].iloc[0] == "input.json"
@ -33,37 +39,43 @@ async def test_json_loader_one_file_multiple_objects():

async def test_json_loader_one_file_with_title():
    config = InputConfig(
        type=InputType.file,
        storage=StorageConfig(
            base_dir="tests/unit/indexing/input/data/one-json-one-object",
        ),
        file_type=InputFileType.json,
        file_pattern=".*\\.json$",
        base_dir="tests/unit/indexing/input/data/one-json-one-object",
        title_column="title",
    )
    documents = await create_input(config=config)
    storage = create_storage_from_config(config.storage)
    documents = await create_input(config=config, storage=storage)
    assert documents.shape == (1, 4)
    assert documents["title"].iloc[0] == "Hello"


async def test_json_loader_one_file_with_metadata():
    config = InputConfig(
        type=InputType.file,
        storage=StorageConfig(
            base_dir="tests/unit/indexing/input/data/one-json-one-object",
        ),
        file_type=InputFileType.json,
        file_pattern=".*\\.json$",
        base_dir="tests/unit/indexing/input/data/one-json-one-object",
        title_column="title",
        metadata=["title"],
    )
    documents = await create_input(config=config)
    storage = create_storage_from_config(config.storage)
    documents = await create_input(config=config, storage=storage)
    assert documents.shape == (1, 5)
    assert documents["metadata"][0] == {"title": "Hello"}


async def test_json_loader_multiple_files():
    config = InputConfig(
        type=InputType.file,
        storage=StorageConfig(
            base_dir="tests/unit/indexing/input/data/multiple-jsons",
        ),
        file_type=InputFileType.json,
        file_pattern=".*\\.json$",
        base_dir="tests/unit/indexing/input/data/multiple-jsons",
    )
    documents = await create_input(config=config)
    storage = create_storage_from_config(config.storage)
    documents = await create_input(config=config, storage=storage)
    assert documents.shape == (4, 4)
@ -1,32 +1,38 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

from graphrag.config.enums import InputFileType, InputType
from graphrag.config.enums import InputFileType
from graphrag.config.models.input_config import InputConfig
from graphrag.config.models.storage_config import StorageConfig
from graphrag.index.input.factory import create_input
from graphrag.utils.api import create_storage_from_config


async def test_txt_loader_one_file():
    config = InputConfig(
        type=InputType.file,
        storage=StorageConfig(
            base_dir="tests/unit/indexing/input/data/one-txt",
        ),
        file_type=InputFileType.text,
        file_pattern=".*\\.txt$",
        base_dir="tests/unit/indexing/input/data/one-txt",
    )
    documents = await create_input(config=config)
    storage = create_storage_from_config(config.storage)
    documents = await create_input(config=config, storage=storage)
    assert documents.shape == (1, 4)
    assert documents["title"].iloc[0] == "input.txt"


async def test_txt_loader_one_file_with_metadata():
    config = InputConfig(
        type=InputType.file,
        storage=StorageConfig(
            base_dir="tests/unit/indexing/input/data/one-txt",
        ),
        file_type=InputFileType.text,
        file_pattern=".*\\.txt$",
        base_dir="tests/unit/indexing/input/data/one-txt",
        metadata=["title"],
    )
    documents = await create_input(config=config)
    storage = create_storage_from_config(config.storage)
    documents = await create_input(config=config, storage=storage)
    assert documents.shape == (1, 5)
    # unlike csv, we cannot set the title to anything other than the filename
    assert documents["metadata"][0] == {"title": "input.txt"}
@ -34,10 +40,12 @@ async def test_txt_loader_one_file_with_metadata():

async def test_txt_loader_multiple_files():
    config = InputConfig(
        type=InputType.file,
        storage=StorageConfig(
            base_dir="tests/unit/indexing/input/data/multiple-txts",
        ),
        file_type=InputFileType.text,
        file_pattern=".*\\.txt$",
        base_dir="tests/unit/indexing/input/data/multiple-txts",
    )
    documents = await create_input(config=config)
    storage = create_storage_from_config(config.storage)
    documents = await create_input(config=config, storage=storage)
    assert documents.shape == (2, 4)
@ -23,7 +23,7 @@ async def test_create_base_text_units():

    await run_workflow(config, context)

    actual = await load_table_from_storage("text_units", context.storage)
    actual = await load_table_from_storage("text_units", context.output_storage)

    compare_outputs(actual, expected, columns=["text", "document_ids", "n_tokens"])

@ -43,7 +43,7 @@ async def test_create_base_text_units_metadata():

    await run_workflow(config, context)

    actual = await load_table_from_storage("text_units", context.storage)
    actual = await load_table_from_storage("text_units", context.output_storage)
    compare_outputs(actual, expected)


@ -63,6 +63,6 @@ async def test_create_base_text_units_metadata_included_in_chunk():

    await run_workflow(config, context)

    actual = await load_table_from_storage("text_units", context.storage)
    actual = await load_table_from_storage("text_units", context.output_storage)
    # only check the columns from the base workflow - our expected table is the final and will have more
    compare_outputs(actual, expected, columns=["text", "document_ids", "n_tokens"])

@ -33,7 +33,7 @@ async def test_create_communities():
        context,
    )

    actual = await load_table_from_storage("communities", context.storage)
    actual = await load_table_from_storage("communities", context.output_storage)

    columns = list(expected.columns.values)
    # don't compare period since it is created with the current date each time

@ -66,7 +66,7 @@ async def test_create_community_reports():

    await run_workflow(config, context)

    actual = await load_table_from_storage("community_reports", context.storage)
    actual = await load_table_from_storage("community_reports", context.output_storage)

    assert len(actual.columns) == len(expected.columns)


@ -28,7 +28,7 @@ async def test_create_final_documents():

    await run_workflow(config, context)

    actual = await load_table_from_storage("documents", context.storage)
    actual = await load_table_from_storage("documents", context.output_storage)

    compare_outputs(actual, expected)

@ -47,11 +47,11 @@ async def test_create_final_documents_with_metadata_column():
    # simulate the metadata construction during initial input loading
    await update_document_metadata(config.input.metadata, context)

    expected = await load_table_from_storage("documents", context.storage)
    expected = await load_table_from_storage("documents", context.output_storage)

    await run_workflow(config, context)

    actual = await load_table_from_storage("documents", context.storage)
    actual = await load_table_from_storage("documents", context.output_storage)

    compare_outputs(actual, expected)


@ -33,7 +33,7 @@ async def test_create_final_text_units():

    await run_workflow(config, context)

    actual = await load_table_from_storage("text_units", context.storage)
    actual = await load_table_from_storage("text_units", context.output_storage)

    for column in TEXT_UNITS_FINAL_COLUMNS:
        assert column in actual.columns
@ -37,6 +37,7 @@ async def test_extract_covariates():
    ).model_dump()
    llm_settings["type"] = ModelType.MockChat
    llm_settings["responses"] = MOCK_LLM_RESPONSES
    config.extract_claims.enabled = True
    config.extract_claims.strategy = {
        "type": "graph_intelligence",
        "llm": llm_settings,
@ -45,7 +46,7 @@ async def test_extract_covariates():

    await run_workflow(config, context)

    actual = await load_table_from_storage("covariates", context.storage)
    actual = await load_table_from_storage("covariates", context.output_storage)

    for column in COVARIATES_FINAL_COLUMNS:
        assert column in actual.columns

@ -63,8 +63,10 @@ async def test_extract_graph():

    await run_workflow(config, context)

    nodes_actual = await load_table_from_storage("entities", context.storage)
    edges_actual = await load_table_from_storage("relationships", context.storage)
    nodes_actual = await load_table_from_storage("entities", context.output_storage)
    edges_actual = await load_table_from_storage(
        "relationships", context.output_storage
    )

    assert len(nodes_actual.columns) == 5
    assert len(edges_actual.columns) == 5

@ -22,8 +22,10 @@ async def test_extract_graph_nlp():

    await run_workflow(config, context)

    nodes_actual = await load_table_from_storage("entities", context.storage)
    edges_actual = await load_table_from_storage("relationships", context.storage)
    nodes_actual = await load_table_from_storage("entities", context.output_storage)
    edges_actual = await load_table_from_storage(
        "relationships", context.output_storage
    )

    # this will be the raw count of entities and edges with no pruning
    # with NLP it is deterministic, so we can assert exact row counts

@ -25,8 +25,10 @@ async def test_finalize_graph():

    await run_workflow(config, context)

    nodes_actual = await load_table_from_storage("entities", context.storage)
    edges_actual = await load_table_from_storage("relationships", context.storage)
    nodes_actual = await load_table_from_storage("entities", context.output_storage)
    edges_actual = await load_table_from_storage(
        "relationships", context.output_storage
    )

    assert len(nodes_actual) == 291
    assert len(edges_actual) == 452
@ -51,8 +53,10 @@ async def test_finalize_graph_umap():

    await run_workflow(config, context)

    nodes_actual = await load_table_from_storage("entities", context.storage)
    edges_actual = await load_table_from_storage("relationships", context.storage)
    nodes_actual = await load_table_from_storage("entities", context.output_storage)
    edges_actual = await load_table_from_storage(
        "relationships", context.output_storage
    )

    assert len(nodes_actual) == 291
    assert len(edges_actual) == 452
@ -75,8 +79,8 @@ async def _prep_tables():
    # edit the tables to eliminate final fields that wouldn't be on the inputs
    entities = load_test_table("entities")
    entities.drop(columns=["x", "y", "degree"], inplace=True)
    await write_table_to_storage(entities, "entities", context.storage)
    await write_table_to_storage(entities, "entities", context.output_storage)
    relationships = load_test_table("relationships")
    relationships.drop(columns=["combined_degree"], inplace=True)
    await write_table_to_storage(relationships, "relationships", context.storage)
    await write_table_to_storage(relationships, "relationships", context.output_storage)
    return context
@ -44,14 +44,14 @@ async def test_generate_text_embeddings():

    await run_workflow(config, context)

    parquet_files = context.storage.keys()
    parquet_files = context.output_storage.keys()

    for field in all_embeddings:
        assert f"embeddings.{field}.parquet" in parquet_files

    # entity description should always be here, let's assert its format
    entity_description_embeddings = await load_table_from_storage(
        "embeddings.entity.description", context.storage
        "embeddings.entity.description", context.output_storage
    )

    assert len(entity_description_embeddings.columns) == 2
@ -60,7 +60,7 @@ async def test_generate_text_embeddings():

    # every other embedding is optional but we've turned them all on, so check a random one
    document_text_embeddings = await load_table_from_storage(
        "embeddings.document.text", context.storage
        "embeddings.document.text", context.output_storage
    )

    assert len(document_text_embeddings.columns) == 2

@ -26,6 +26,6 @@ async def test_prune_graph():

    await run_workflow(config, context)

    nodes_actual = await load_table_from_storage("entities", context.storage)
    nodes_actual = await load_table_from_storage("entities", context.output_storage)

    assert len(nodes_actual) == 21

@ -38,12 +38,12 @@ async def create_test_context(storage: list[str] | None = None) -> PipelineRunCo
    # always set the input docs, but since our stored table is final, drop what wouldn't be in the original source input
    input = load_test_table("documents")
    input.drop(columns=["text_unit_ids"], inplace=True)
    await write_table_to_storage(input, "documents", context.storage)
    await write_table_to_storage(input, "documents", context.output_storage)

    if storage:
        for name in storage:
            table = load_test_table(name)
            await write_table_to_storage(table, name, context.storage)
            await write_table_to_storage(table, name, context.output_storage)

    return context

@ -86,8 +86,8 @@ def compare_outputs(

async def update_document_metadata(metadata: list[str], context: PipelineRunContext):
    """Takes the default documents and adds the configured metadata columns for later parsing by the text units and final documents workflows."""
    documents = await load_table_from_storage("documents", context.storage)
    documents = await load_table_from_storage("documents", context.output_storage)
    documents["metadata"] = documents[metadata].apply(lambda row: row.to_dict(), axis=1)
    await write_table_to_storage(
        documents, "documents", context.storage
        documents, "documents", context.output_storage
    )  # write to the runtime context storage only