Pipeline callbacks (#1729)

* Add pipeline_start and pipeline_end callbacks

* Collapse redundant callback/logger logic

* Remove redundant reporting config classes

* Remove a few out-of-date type ignores

* Add semver change file

---------

Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
Nathan Evans 2025-02-25 15:07:51 -08:00 committed by GitHub
parent e40476153d
commit ede6a74546
15 changed files with 150 additions and 156 deletions
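
To illustrate the new hooks, here is a minimal consumer-side sketch; the base class, method signatures, and PipelineRunResult fields are taken from the diffs below, while the class name and print bodies are purely illustrative:

from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.index.run.pipeline_run_result import PipelineRunResult

class LoggingCallbacks(NoopWorkflowCallbacks):
    """Report pipeline boundaries; every other callback stays a no-op."""

    def pipeline_start(self, names: list[str]) -> None:
        # Invoked once before any workflow runs, with the ordered workflow names.
        print(f"pipeline starting: {names}")

    def pipeline_end(self, results: list[PipelineRunResult]) -> None:
        # Invoked once after all workflows finish, one PipelineRunResult per workflow.
        failed = [r.workflow for r in results if r.errors]
        print(f"pipeline finished; workflows with errors: {failed}")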

View File

@ -0,0 +1,4 @@
{
"type": "minor",
"description": "Add pipeline_start and pipeline_end callbacks."
}

View File

@ -10,15 +10,17 @@ Backwards compatibility is not guaranteed at this time.
import logging
from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
from graphrag.callbacks.reporting import create_pipeline_reporter
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.enums import CacheType, IndexingMethod
from graphrag.config.enums import IndexingMethod
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.run.pipeline_run_result import PipelineRunResult
from graphrag.index.run.run_pipeline import run_pipeline
from graphrag.index.typing import PipelineRunResult, WorkflowFunction
from graphrag.index.run.utils import create_callback_chain
from graphrag.index.typing import WorkflowFunction
from graphrag.index.workflows.factory import PipelineFactory
from graphrag.logger.base import ProgressLogger
from graphrag.logger.null_progress import NullProgressLogger
log = logging.getLogger(__name__)
@ -51,13 +53,13 @@ async def build_index(
list[PipelineRunResult]
The list of pipeline run results
"""
pipeline_cache = (
NoopPipelineCache() if config.cache.type == CacheType.none is None else None
)
logger = progress_logger or NullProgressLogger()
# create a pipeline reporter and add to any additional callbacks
# TODO: remove the type ignore once the new config engine has been refactored
callbacks = callbacks or []
callbacks.append(create_pipeline_reporter(config.reporting, None)) # type: ignore
callbacks.append(create_pipeline_reporter(config.reporting, None))
workflow_callbacks = create_callback_chain(callbacks, logger)
outputs: list[PipelineRunResult] = []
if memory_profile:
@ -65,22 +67,23 @@ async def build_index(
pipeline = PipelineFactory.create_pipeline(config, method)
workflow_callbacks.pipeline_start(pipeline.names())
async for output in run_pipeline(
pipeline,
config,
cache=pipeline_cache,
callbacks=callbacks,
logger=progress_logger,
callbacks=workflow_callbacks,
logger=logger,
is_update_run=is_update_run,
):
outputs.append(output)
if progress_logger:
if output.errors and len(output.errors) > 0:
progress_logger.error(output.workflow)
else:
progress_logger.success(output.workflow)
progress_logger.info(str(output.result))
if output.errors and len(output.errors) > 0:
logger.error(output.workflow)
else:
logger.success(output.workflow)
logger.info(str(output.result))
workflow_callbacks.pipeline_end(outputs)
return outputs
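
Calling build_index with such callbacks would then look roughly like the sketch below. The keyword names mirror the parameters used in this hunk; the module path for build_index is an assumption, and LoggingCallbacks is the illustrative class from the top of this commit:

from graphrag.api.index import build_index  # assumed module path for the function in this hunk
from graphrag.config.models.graph_rag_config import GraphRagConfig

async def index_with_callbacks(config: GraphRagConfig) -> None:
    # User callbacks are merged with the pipeline reporter via create_callback_chain.
    results = await build_index(
        config=config,
        callbacks=[LoggingCallbacks()],
        progress_logger=None,  # None falls back to NullProgressLogger, per this hunk
    )
    for result in results:
        print(result.workflow, "errors" if result.errors else "ok")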

View File

@ -435,10 +435,11 @@ def local_search_streaming(
vector_store_args = {}
for index, store in config.vector_store.items():
vector_store_args[index] = store.model_dump()
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
msg = f"Vector Store Args: {redact(vector_store_args)}"
logger.info(msg)
description_embedding_store = get_embedding_store(
config_args=vector_store_args, # type: ignore
config_args=vector_store_args,
embedding_name=entity_description_embedding,
)
@ -453,7 +454,7 @@ def local_search_streaming(
entities=entities_,
relationships=read_indexer_relationships(relationships),
covariates={"claims": covariates_},
description_embedding_store=description_embedding_store, # type: ignore
description_embedding_store=description_embedding_store,
response_type=response_type,
system_prompt=prompt,
callbacks=callbacks,
@ -789,15 +790,16 @@ def drift_search_streaming(
vector_store_args = {}
for index, store in config.vector_store.items():
vector_store_args[index] = store.model_dump()
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
msg = f"Vector Store Args: {redact(vector_store_args)}"
logger.info(msg)
description_embedding_store = get_embedding_store(
config_args=vector_store_args, # type: ignore
config_args=vector_store_args,
embedding_name=entity_description_embedding,
)
full_content_embedding_store = get_embedding_store(
config_args=vector_store_args, # type: ignore
config_args=vector_store_args,
embedding_name=community_full_content_embedding,
)
@ -815,7 +817,7 @@ def drift_search_streaming(
text_units=read_indexer_text_units(text_units),
entities=entities_,
relationships=read_indexer_relationships(relationships),
description_embedding_store=description_embedding_store, # type: ignore
description_embedding_store=description_embedding_store,
local_system_prompt=prompt,
reduce_system_prompt=reduce_prompt,
response_type=response_type,
@ -1104,10 +1106,11 @@ def basic_search_streaming(
vector_store_args = {}
for index, store in config.vector_store.items():
vector_store_args[index] = store.model_dump()
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
msg = f"Vector Store Args: {redact(vector_store_args)}"
logger.info(msg)
description_embedding_store = get_embedding_store(
config_args=vector_store_args, # type: ignore
config_args=vector_store_args,
embedding_name=text_unit_text_embedding,
)

View File

@ -24,11 +24,11 @@ class BlobWorkflowCallbacks(NoopWorkflowCallbacks):
def __init__(
self,
connection_string: str | None,
container_name: str,
container_name: str | None,
blob_name: str = "",
base_dir: str | None = None,
storage_account_blob_url: str | None = None,
): # type: ignore
):
"""Create a new instance of the BlobStorageReporter class."""
if container_name is None:
msg = "No container name provided for blob storage."

View File

@ -4,12 +4,19 @@
"""A no-op implementation of WorkflowCallbacks."""
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.index.run.pipeline_run_result import PipelineRunResult
from graphrag.logger.progress import Progress
class NoopWorkflowCallbacks(WorkflowCallbacks):
"""A no-op implementation of WorkflowCallbacks."""
def pipeline_start(self, names: list[str]) -> None:
"""Execute this callback when a the entire pipeline starts."""
def pipeline_end(self, results: list[PipelineRunResult]) -> None:
"""Execute this callback when the entire pipeline ends."""
def workflow_start(self, name: str, instance: object) -> None:
"""Execute this callback when a workflow starts."""

View File

@ -1,110 +1,39 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing 'PipelineReportingConfig', 'PipelineFileReportingConfig' and 'PipelineConsoleReportingConfig' models."""
"""A module containing the pipeline reporter factory."""
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING, Generic, Literal, TypeVar, cast
from pydantic import BaseModel, Field
from typing import TYPE_CHECKING
from graphrag.callbacks.blob_workflow_callbacks import BlobWorkflowCallbacks
from graphrag.callbacks.console_workflow_callbacks import ConsoleWorkflowCallbacks
from graphrag.callbacks.file_workflow_callbacks import FileWorkflowCallbacks
from graphrag.config.enums import ReportingType
from graphrag.config.models.reporting_config import ReportingConfig
if TYPE_CHECKING:
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
T = TypeVar("T")
class PipelineReportingConfig(BaseModel, Generic[T]):
"""Represent the reporting configuration for the pipeline."""
type: T
class PipelineFileReportingConfig(PipelineReportingConfig[Literal[ReportingType.file]]):
"""Represent the file reporting configuration for the pipeline."""
type: Literal[ReportingType.file] = ReportingType.file
"""The type of reporting."""
base_dir: str | None = Field(
description="The base directory for the reporting.", default=None
)
"""The base directory for the reporting."""
class PipelineConsoleReportingConfig(
PipelineReportingConfig[Literal[ReportingType.console]]
):
"""Represent the console reporting configuration for the pipeline."""
type: Literal[ReportingType.console] = ReportingType.console
"""The type of reporting."""
class PipelineBlobReportingConfig(PipelineReportingConfig[Literal[ReportingType.blob]]):
"""Represents the blob reporting configuration for the pipeline."""
type: Literal[ReportingType.blob] = ReportingType.blob
"""The type of reporting."""
connection_string: str | None = Field(
description="The blob reporting connection string for the reporting.",
default=None,
)
"""The blob reporting connection string for the reporting."""
container_name: str = Field(
description="The container name for reporting", default=""
)
"""The container name for reporting"""
storage_account_blob_url: str | None = Field(
description="The storage account blob url for reporting", default=None
)
"""The storage account blob url for reporting"""
base_dir: str | None = Field(
description="The base directory for the reporting.", default=None
)
"""The base directory for the reporting."""
PipelineReportingConfigTypes = (
PipelineFileReportingConfig
| PipelineConsoleReportingConfig
| PipelineBlobReportingConfig
)
def create_pipeline_reporter(
config: PipelineReportingConfig | None, root_dir: str | None
config: ReportingConfig | None, root_dir: str | None
) -> WorkflowCallbacks:
"""Create a logger for the given pipeline config."""
config = config or PipelineFileReportingConfig(base_dir="logs")
config = config or ReportingConfig(base_dir="logs", type=ReportingType.file)
match config.type:
case ReportingType.file:
config = cast("PipelineFileReportingConfig", config)
return FileWorkflowCallbacks(
str(Path(root_dir or "") / (config.base_dir or ""))
)
case ReportingType.console:
return ConsoleWorkflowCallbacks()
case ReportingType.blob:
config = cast("PipelineBlobReportingConfig", config)
return BlobWorkflowCallbacks(
config.connection_string,
config.container_name,
base_dir=config.base_dir,
storage_account_blob_url=config.storage_account_blob_url,
)
case _:
msg = f"Unknown reporting type: {config.type}"
raise ValueError(msg)
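
Because the factory now consumes the standard ReportingConfig directly, wiring up a reporter is a one-liner. A small sketch using the same file/"logs" fallback as above; the root_dir value is illustrative:

from graphrag.callbacks.reporting import create_pipeline_reporter
from graphrag.config.enums import ReportingType
from graphrag.config.models.reporting_config import ReportingConfig

# File-based reporting written under <root_dir>/<base_dir>, here ./my_project/logs.
reporting = ReportingConfig(type=ReportingType.file, base_dir="logs")
reporter = create_pipeline_reporter(reporting, root_dir="./my_project")

# Passing None for the config falls back to file reporting under "logs", as shown above.
default_reporter = create_pipeline_reporter(None, None)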

View File

@ -5,6 +5,7 @@
from typing import Protocol
from graphrag.index.run.pipeline_run_result import PipelineRunResult
from graphrag.logger.progress import Progress
@ -15,6 +16,14 @@ class WorkflowCallbacks(Protocol):
This base class is a "noop" implementation so that clients may implement just the callbacks they need.
"""
def pipeline_start(self, names: list[str]) -> None:
"""Execute this callback to signal when the entire pipeline starts."""
...
def pipeline_end(self, results: list[PipelineRunResult]) -> None:
"""Execute this callback to signal when the entire pipeline ends."""
...
def workflow_start(self, name: str, instance: object) -> None:
"""Execute this callback when a workflow starts."""
...

View File

@ -4,6 +4,7 @@
"""A module containing the WorkflowCallbacks registry."""
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.index.run.pipeline_run_result import PipelineRunResult
from graphrag.logger.progress import Progress
@ -20,6 +21,18 @@ class WorkflowCallbacksManager(WorkflowCallbacks):
"""Register a new WorkflowCallbacks type."""
self._callbacks.append(callbacks)
def pipeline_start(self, names: list[str]) -> None:
"""Execute this callback when a the entire pipeline starts."""
for callback in self._callbacks:
if hasattr(callback, "pipeline_start"):
callback.pipeline_start(names)
def pipeline_end(self, results: list[PipelineRunResult]) -> None:
"""Execute this callback when the entire pipeline ends."""
for callback in self._callbacks:
if hasattr(callback, "pipeline_end"):
callback.pipeline_end(results)
def workflow_start(self, name: str, instance: object) -> None:
"""Execute this callback when a workflow starts."""
for callback in self._callbacks:

View File

@ -527,7 +527,7 @@ def _resolve_output_files(
return dataframe_dict
# Loading output files for single-index search
dataframe_dict["multi-index"] = False
output_config = config.output.model_dump() # type: ignore
output_config = config.output.model_dump()
storage_obj = StorageFactory().create_storage(
storage_type=output_config["type"], kwargs=output_config
)

View File

@ -0,0 +1,23 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing the Pipeline class."""
from collections.abc import Generator
from graphrag.index.typing import Workflow
class Pipeline:
"""Encapsulates running workflows."""
def __init__(self, workflows: list[Workflow]):
self.workflows = workflows
def run(self) -> Generator[Workflow]:
"""Return a Generator over the pipeline workflows."""
yield from self.workflows
def names(self) -> list[str]:
"""Return the names of the workflows in the pipeline."""
return [name for name, _ in self.workflows]
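
This class replaces the bare generator alias that typing.py used to export: it keeps the ordered (name, function) tuples so the runner can list workflow names up front (for pipeline_start) and still iterate them in order. A rough sketch, with placeholder async functions standing in for real WorkflowFunction implementations:

from graphrag.index.run.pipeline import Pipeline

async def load_documents(*args, **kwargs):
    """Placeholder workflow function; the real signature lives in graphrag.index.typing."""

async def build_graph(*args, **kwargs):
    """Placeholder workflow function; the real signature lives in graphrag.index.typing."""

pipeline = Pipeline([
    ("load_documents", load_documents),
    ("build_graph", build_graph),
])

print(pipeline.names())  # ["load_documents", "build_graph"]
for name, workflow_fn in pipeline.run():
    print("next workflow:", name)  # the runner awaits workflow_fn(...) at this point

PipelineFactory.create_pipeline (further down) now returns one of these instead of yielding (name, function) tuples itself.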

View File

@ -0,0 +1,22 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing the PipelineRunResult class."""
from dataclasses import dataclass
from typing import Any
from graphrag.config.models.graph_rag_config import GraphRagConfig
@dataclass
class PipelineRunResult:
"""Pipeline run result class definition."""
workflow: str
"""The name of the workflow that was executed."""
result: Any | None
"""The result of the workflow function. This can be anything - we use it only for logging downstream, and expect each workflow function to write official outputs to the provided storage."""
config: GraphRagConfig | None
"""Final config after running the workflow, which may have been mutated."""
errors: list[BaseException] | None

View File

@ -15,20 +15,19 @@ import pandas as pd
from graphrag.cache.factory import CacheFactory
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.console_workflow_callbacks import ConsoleWorkflowCallbacks
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.context import PipelineRunStats
from graphrag.index.input.factory import create_input
from graphrag.index.run.utils import create_callback_chain, create_run_context
from graphrag.index.typing import Pipeline, PipelineRunResult
from graphrag.index.run.pipeline import Pipeline
from graphrag.index.run.pipeline_run_result import PipelineRunResult
from graphrag.index.run.utils import create_run_context
from graphrag.index.update.incremental_index import (
get_delta_docs,
update_dataframe_outputs,
)
from graphrag.logger.base import ProgressLogger
from graphrag.logger.null_progress import NullProgressLogger
from graphrag.logger.progress import Progress
from graphrag.storage.factory import StorageFactory
from graphrag.storage.pipeline_storage import PipelineStorage
@ -40,24 +39,21 @@ log = logging.getLogger(__name__)
async def run_pipeline(
pipeline: Pipeline,
config: GraphRagConfig,
cache: PipelineCache | None = None,
callbacks: list[WorkflowCallbacks] | None = None,
logger: ProgressLogger | None = None,
callbacks: WorkflowCallbacks,
logger: ProgressLogger,
is_update_run: bool = False,
) -> AsyncIterable[PipelineRunResult]:
"""Run all workflows using a simplified pipeline."""
root_dir = config.root_dir
progress_logger = logger or NullProgressLogger()
callbacks = callbacks or [ConsoleWorkflowCallbacks()]
callback_chain = create_callback_chain(callbacks, progress_logger)
storage_config = config.output.model_dump() # type: ignore
storage_config = config.output.model_dump()
storage = StorageFactory().create_storage(
storage_type=storage_config["type"], # type: ignore
storage_type=storage_config["type"],
kwargs=storage_config,
)
cache_config = config.cache.model_dump() # type: ignore
cache = cache or CacheFactory().create_cache(
cache_type=cache_config["type"], # type: ignore
cache_config = config.cache.model_dump()
cache = CacheFactory().create_cache(
cache_type=cache_config["type"],
root_dir=root_dir,
kwargs=cache_config,
)
@ -65,18 +61,18 @@ async def run_pipeline(
dataset = await create_input(config.input, logger, root_dir)
if is_update_run:
progress_logger.info("Running incremental indexing.")
logger.info("Running incremental indexing.")
delta_dataset = await get_delta_docs(dataset, storage)
# warn on empty delta dataset
if delta_dataset.new_inputs.empty:
warning_msg = "Incremental indexing found no new documents, exiting."
progress_logger.warning(warning_msg)
logger.warning(warning_msg)
else:
update_storage_config = config.update_index_output.model_dump() # type: ignore
update_storage_config = config.update_index_output.model_dump()
update_storage = StorageFactory().create_storage(
storage_type=update_storage_config["type"], # type: ignore
storage_type=update_storage_config["type"],
kwargs=update_storage_config,
)
# we use this to store the new subset index, and will merge its content with the previous index
@ -94,12 +90,12 @@ async def run_pipeline(
dataset=delta_dataset.new_inputs,
cache=cache,
storage=delta_storage,
callbacks=callback_chain,
logger=progress_logger,
callbacks=callbacks,
logger=logger,
):
yield table
progress_logger.success("Finished running workflows on new documents.")
logger.success("Finished running workflows on new documents.")
await update_dataframe_outputs(
previous_storage=previous_storage,
@ -108,11 +104,11 @@ async def run_pipeline(
config=config,
cache=cache,
callbacks=NoopWorkflowCallbacks(),
progress_logger=progress_logger,
progress_logger=logger,
)
else:
progress_logger.info("Running standard indexing.")
logger.info("Running standard indexing.")
async for table in _run_pipeline(
pipeline=pipeline,
@ -120,8 +116,8 @@ async def run_pipeline(
dataset=dataset,
cache=cache,
storage=storage,
callbacks=callback_chain,
logger=progress_logger,
callbacks=callbacks,
logger=logger,
):
yield table
@ -148,7 +144,7 @@ async def _run_pipeline(
await _dump_stats(context.stats, context.storage)
await write_table_to_storage(dataset, "documents", context.storage)
for name, workflow_function in pipeline:
for name, workflow_function in pipeline.run():
last_workflow = name
progress = logger.child(name, transient=False)
callbacks.workflow_start(name, None)

View File

@ -3,7 +3,7 @@
"""A module containing the 'PipelineRunResult' model."""
from collections.abc import Awaitable, Callable, Generator
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Any
@ -29,18 +29,3 @@ WorkflowFunction = Callable[
Awaitable[WorkflowFunctionOutput],
]
Workflow = tuple[str, WorkflowFunction]
Pipeline = Generator[Workflow]
@dataclass
class PipelineRunResult:
"""Pipeline run result class definition."""
workflow: str
"""The name of the workflow that was executed."""
result: Any | None
"""The result of the workflow function. This can be anything - we use it only for logging downstream, and expect each workflow function to write official outputs to the provided storage."""
config: GraphRagConfig | None
"""Final config after running the workflow, which may have been mutated."""
errors: list[BaseException] | None

View File

@ -7,7 +7,8 @@ from typing import ClassVar
from graphrag.config.enums import IndexingMethod
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.typing import Pipeline, WorkflowFunction
from graphrag.index.run.pipeline import Pipeline
from graphrag.index.typing import WorkflowFunction
class PipelineFactory:
@ -32,8 +33,7 @@ class PipelineFactory:
) -> Pipeline:
"""Create a pipeline generator."""
workflows = _get_workflows_list(config, method)
for name in workflows:
yield name, cls.workflows[name]
return Pipeline([(name, cls.workflows[name]) for name in workflows])
def _get_workflows_list(

View File

@ -160,7 +160,7 @@ class TestRunChain(unittest.IsolatedAsyncioTestCase):
),
)
graph = results.graph # type: ignore
graph = results.graph
assert graph is not None, "No graph returned!"
# TODO: The edges might come back in any order, but we're assuming they're coming
@ -210,7 +210,7 @@ class TestRunChain(unittest.IsolatedAsyncioTestCase):
),
)
graph = results.graph # type: ignore
graph = results.graph
assert graph is not None, "No graph returned!"
edges = list(graph.edges(data=True))
@ -218,6 +218,6 @@ class TestRunChain(unittest.IsolatedAsyncioTestCase):
assert len(edges) == 2
# Sort by source_id for consistent ordering
edge_source_ids = sorted([edge[2].get("source_id", "") for edge in edges]) # type: ignore
assert edge_source_ids[0].split(",") == ["1"] # type: ignore
assert edge_source_ids[1].split(",") == ["2"] # type: ignore
edge_source_ids = sorted([edge[2].get("source_id", "") for edge in edges])
assert edge_source_ids[0].split(",") == ["1"]
assert edge_source_ids[1].split(",") == ["2"]