Mirror of https://github.com/microsoft/graphrag.git (synced 2025-06-26 23:19:58 +00:00)
Pipeline callbacks (#1729)
* Add pipeline_start and pipeline_end callbacks
* Collapse redundant callback/logger logic
* Remove redundant reporting config classes
* Remove a few out-of-date type ignores
* Semver

Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
parent: e40476153d, commit: ede6a74546
@@ -0,0 +1,4 @@
+{
+    "type": "minor",
+    "description": "Add pipeline_start and pipeline_end callbacks."
+}
@@ -10,15 +10,17 @@ Backwards compatibility is not guaranteed at this time.

 import logging

-from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
 from graphrag.callbacks.reporting import create_pipeline_reporter
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
-from graphrag.config.enums import CacheType, IndexingMethod
+from graphrag.config.enums import IndexingMethod
 from graphrag.config.models.graph_rag_config import GraphRagConfig
+from graphrag.index.run.pipeline_run_result import PipelineRunResult
 from graphrag.index.run.run_pipeline import run_pipeline
-from graphrag.index.typing import PipelineRunResult, WorkflowFunction
+from graphrag.index.run.utils import create_callback_chain
+from graphrag.index.typing import WorkflowFunction
 from graphrag.index.workflows.factory import PipelineFactory
 from graphrag.logger.base import ProgressLogger
 from graphrag.logger.null_progress import NullProgressLogger

 log = logging.getLogger(__name__)
@@ -51,13 +53,13 @@ async def build_index(
     list[PipelineRunResult]
         The list of pipeline run results
     """
-    pipeline_cache = (
-        NoopPipelineCache() if config.cache.type == CacheType.none is None else None
-    )
+    logger = progress_logger or NullProgressLogger()
     # create a pipeline reporter and add to any additional callbacks
-    # TODO: remove the type ignore once the new config engine has been refactored
     callbacks = callbacks or []
-    callbacks.append(create_pipeline_reporter(config.reporting, None))  # type: ignore
+    callbacks.append(create_pipeline_reporter(config.reporting, None))
+
+    workflow_callbacks = create_callback_chain(callbacks, logger)

     outputs: list[PipelineRunResult] = []

     if memory_profile:
@@ -65,22 +67,23 @@ async def build_index(

     pipeline = PipelineFactory.create_pipeline(config, method)

+    workflow_callbacks.pipeline_start(pipeline.names())
+
     async for output in run_pipeline(
         pipeline,
         config,
-        cache=pipeline_cache,
-        callbacks=callbacks,
-        logger=progress_logger,
+        callbacks=workflow_callbacks,
+        logger=logger,
         is_update_run=is_update_run,
     ):
         outputs.append(output)
-        if progress_logger:
-            if output.errors and len(output.errors) > 0:
-                progress_logger.error(output.workflow)
-            else:
-                progress_logger.success(output.workflow)
-            progress_logger.info(str(output.result))
+        if output.errors and len(output.errors) > 0:
+            logger.error(output.workflow)
+        else:
+            logger.success(output.workflow)
+        logger.info(str(output.result))

+    workflow_callbacks.pipeline_end(outputs)
     return outputs
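Taken together, the build_index changes let a caller observe the whole indexing lifecycle by passing its own callbacks list. A minimal sketch of what that might look like from the outside; the load_config helper and any build_index keyword not visible in this diff are assumptions, not part of this commit:

```python
import asyncio
from pathlib import Path

import graphrag.api as api
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.load_config import load_config  # assumed config loader, not part of this diff
from graphrag.index.run.pipeline_run_result import PipelineRunResult


class LifecycleLogger(NoopWorkflowCallbacks):
    """Implements only the two new hooks; every other callback stays a no-op."""

    def pipeline_start(self, names: list[str]) -> None:
        print(f"pipeline starting, workflows: {names}")

    def pipeline_end(self, results: list[PipelineRunResult]) -> None:
        failed = [r.workflow for r in results if r.errors]
        print(f"pipeline finished, failed workflows: {failed or 'none'}")


async def main() -> None:
    # assumes a graphrag project (settings.yaml) already initialized in the cwd
    config = load_config(Path("."))
    results = await api.build_index(config=config, callbacks=[LifecycleLogger()])
    for r in results:
        print(r.workflow, "ok" if not r.errors else r.errors)


if __name__ == "__main__":
    asyncio.run(main())
```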
@@ -435,10 +435,11 @@ def local_search_streaming(
     vector_store_args = {}
     for index, store in config.vector_store.items():
         vector_store_args[index] = store.model_dump()
-    logger.info(f"Vector Store Args: {redact(vector_store_args)}")  # type: ignore # noqa
+    msg = f"Vector Store Args: {redact(vector_store_args)}"
+    logger.info(msg)

     description_embedding_store = get_embedding_store(
-        config_args=vector_store_args,  # type: ignore
+        config_args=vector_store_args,
         embedding_name=entity_description_embedding,
     )
@@ -453,7 +454,7 @@ def local_search_streaming(
         entities=entities_,
         relationships=read_indexer_relationships(relationships),
         covariates={"claims": covariates_},
-        description_embedding_store=description_embedding_store,  # type: ignore
+        description_embedding_store=description_embedding_store,
         response_type=response_type,
         system_prompt=prompt,
         callbacks=callbacks,
@@ -789,15 +790,16 @@ def drift_search_streaming(
     vector_store_args = {}
     for index, store in config.vector_store.items():
         vector_store_args[index] = store.model_dump()
-    logger.info(f"Vector Store Args: {redact(vector_store_args)}")  # type: ignore # noqa
+    msg = f"Vector Store Args: {redact(vector_store_args)}"
+    logger.info(msg)

     description_embedding_store = get_embedding_store(
-        config_args=vector_store_args,  # type: ignore
+        config_args=vector_store_args,
         embedding_name=entity_description_embedding,
     )

     full_content_embedding_store = get_embedding_store(
-        config_args=vector_store_args,  # type: ignore
+        config_args=vector_store_args,
         embedding_name=community_full_content_embedding,
     )
@@ -815,7 +817,7 @@ def drift_search_streaming(
         text_units=read_indexer_text_units(text_units),
         entities=entities_,
         relationships=read_indexer_relationships(relationships),
-        description_embedding_store=description_embedding_store,  # type: ignore
+        description_embedding_store=description_embedding_store,
         local_system_prompt=prompt,
         reduce_system_prompt=reduce_prompt,
         response_type=response_type,
@@ -1104,10 +1106,11 @@ def basic_search_streaming(
     vector_store_args = {}
     for index, store in config.vector_store.items():
         vector_store_args[index] = store.model_dump()
-    logger.info(f"Vector Store Args: {redact(vector_store_args)}")  # type: ignore # noqa
+    msg = f"Vector Store Args: {redact(vector_store_args)}"
+    logger.info(msg)

     description_embedding_store = get_embedding_store(
-        config_args=vector_store_args,  # type: ignore
+        config_args=vector_store_args,
         embedding_name=text_unit_text_embedding,
     )
@@ -24,11 +24,11 @@ class BlobWorkflowCallbacks(NoopWorkflowCallbacks):
     def __init__(
         self,
         connection_string: str | None,
-        container_name: str,
+        container_name: str | None,
         blob_name: str = "",
         base_dir: str | None = None,
         storage_account_blob_url: str | None = None,
-    ):  # type: ignore
+    ):
         """Create a new instance of the BlobStorageReporter class."""
         if container_name is None:
             msg = "No container name provided for blob storage."
@@ -4,12 +4,19 @@
 """A no-op implementation of WorkflowCallbacks."""

 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
+from graphrag.index.run.pipeline_run_result import PipelineRunResult
 from graphrag.logger.progress import Progress


 class NoopWorkflowCallbacks(WorkflowCallbacks):
     """A no-op implementation of WorkflowCallbacks."""

+    def pipeline_start(self, names: list[str]) -> None:
+        """Execute this callback when the entire pipeline starts."""
+
+    def pipeline_end(self, results: list[PipelineRunResult]) -> None:
+        """Execute this callback when the entire pipeline ends."""
+
     def workflow_start(self, name: str, instance: object) -> None:
         """Execute this callback when a workflow starts."""
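Because every hook on the no-op base is an empty method, existing subclasses keep working unchanged and new code can override only what it needs. A small illustrative sketch (not part of the commit) that records a wall-clock duration for a run via the two new hooks:

```python
import time

from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.index.run.pipeline_run_result import PipelineRunResult


class TimingCallbacks(NoopWorkflowCallbacks):
    """Records how long the whole pipeline took; all other hooks stay no-ops."""

    def __init__(self) -> None:
        self._started_at: float | None = None
        self.duration: float | None = None

    def pipeline_start(self, names: list[str]) -> None:
        # remember when the run began
        self._started_at = time.perf_counter()

    def pipeline_end(self, results: list[PipelineRunResult]) -> None:
        # compute elapsed time once every workflow has finished
        if self._started_at is not None:
            self.duration = time.perf_counter() - self._started_at
```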
@@ -1,110 +1,39 @@
 # Copyright (c) 2024 Microsoft Corporation.
 # Licensed under the MIT License

-"""A module containing 'PipelineReportingConfig', 'PipelineFileReportingConfig' and 'PipelineConsoleReportingConfig' models."""
+"""A module containing the pipeline reporter factory."""

 from __future__ import annotations

 from pathlib import Path
-from typing import TYPE_CHECKING, Generic, Literal, TypeVar, cast
-
-from pydantic import BaseModel, Field
+from typing import TYPE_CHECKING

 from graphrag.callbacks.blob_workflow_callbacks import BlobWorkflowCallbacks
 from graphrag.callbacks.console_workflow_callbacks import ConsoleWorkflowCallbacks
 from graphrag.callbacks.file_workflow_callbacks import FileWorkflowCallbacks
 from graphrag.config.enums import ReportingType
+from graphrag.config.models.reporting_config import ReportingConfig

 if TYPE_CHECKING:
     from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks

-T = TypeVar("T")
-
-
-class PipelineReportingConfig(BaseModel, Generic[T]):
-    """Represent the reporting configuration for the pipeline."""
-
-    type: T
-
-
-class PipelineFileReportingConfig(PipelineReportingConfig[Literal[ReportingType.file]]):
-    """Represent the file reporting configuration for the pipeline."""
-
-    type: Literal[ReportingType.file] = ReportingType.file
-    """The type of reporting."""
-
-    base_dir: str | None = Field(
-        description="The base directory for the reporting.", default=None
-    )
-    """The base directory for the reporting."""
-
-
-class PipelineConsoleReportingConfig(
-    PipelineReportingConfig[Literal[ReportingType.console]]
-):
-    """Represent the console reporting configuration for the pipeline."""
-
-    type: Literal[ReportingType.console] = ReportingType.console
-    """The type of reporting."""
-
-
-class PipelineBlobReportingConfig(PipelineReportingConfig[Literal[ReportingType.blob]]):
-    """Represents the blob reporting configuration for the pipeline."""
-
-    type: Literal[ReportingType.blob] = ReportingType.blob
-    """The type of reporting."""
-
-    connection_string: str | None = Field(
-        description="The blob reporting connection string for the reporting.",
-        default=None,
-    )
-    """The blob reporting connection string for the reporting."""
-
-    container_name: str = Field(
-        description="The container name for reporting", default=""
-    )
-    """The container name for reporting"""
-
-    storage_account_blob_url: str | None = Field(
-        description="The storage account blob url for reporting", default=None
-    )
-    """The storage account blob url for reporting"""
-
-    base_dir: str | None = Field(
-        description="The base directory for the reporting.", default=None
-    )
-    """The base directory for the reporting."""
-
-
-PipelineReportingConfigTypes = (
-    PipelineFileReportingConfig
-    | PipelineConsoleReportingConfig
-    | PipelineBlobReportingConfig
-)
-
-
 def create_pipeline_reporter(
-    config: PipelineReportingConfig | None, root_dir: str | None
+    config: ReportingConfig | None, root_dir: str | None
 ) -> WorkflowCallbacks:
     """Create a logger for the given pipeline config."""
-    config = config or PipelineFileReportingConfig(base_dir="logs")
-
+    config = config or ReportingConfig(base_dir="logs", type=ReportingType.file)
     match config.type:
         case ReportingType.file:
-            config = cast("PipelineFileReportingConfig", config)
             return FileWorkflowCallbacks(
                 str(Path(root_dir or "") / (config.base_dir or ""))
             )
         case ReportingType.console:
             return ConsoleWorkflowCallbacks()
         case ReportingType.blob:
-            config = cast("PipelineBlobReportingConfig", config)
             return BlobWorkflowCallbacks(
                 config.connection_string,
                 config.container_name,
                 base_dir=config.base_dir,
                 storage_account_blob_url=config.storage_account_blob_url,
             )
         case _:
            msg = f"Unknown reporting type: {config.type}"
            raise ValueError(msg)
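After this change the factory consumes the main ReportingConfig model directly instead of a parallel set of pipeline-specific pydantic classes. A hedged usage sketch; the field names are the ones referenced in the diff above, anything beyond that is an assumption:

```python
from graphrag.callbacks.reporting import create_pipeline_reporter
from graphrag.config.enums import ReportingType
from graphrag.config.models.reporting_config import ReportingConfig

# File-based reporting rooted at ./project/logs (writes its report file there)
file_reporter = create_pipeline_reporter(
    ReportingConfig(type=ReportingType.file, base_dir="logs"),
    root_dir="./project",
)

# Passing None falls back to file reporting under "logs", per the default above
default_reporter = create_pipeline_reporter(None, root_dir=None)
```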
@@ -5,6 +5,7 @@

 from typing import Protocol

+from graphrag.index.run.pipeline_run_result import PipelineRunResult
 from graphrag.logger.progress import Progress


@@ -15,6 +16,14 @@ class WorkflowCallbacks(Protocol):
     This base class is a "noop" implementation so that clients may implement just the callbacks they need.
     """

+    def pipeline_start(self, names: list[str]) -> None:
+        """Execute this callback to signal when the entire pipeline starts."""
+        ...
+
+    def pipeline_end(self, results: list[PipelineRunResult]) -> None:
+        """Execute this callback to signal when the entire pipeline ends."""
+        ...
+
     def workflow_start(self, name: str, instance: object) -> None:
         """Execute this callback when a workflow starts."""
         ...
@@ -4,6 +4,7 @@
 """A module containing the WorkflowCallbacks registry."""

 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
+from graphrag.index.run.pipeline_run_result import PipelineRunResult
 from graphrag.logger.progress import Progress


@@ -20,6 +21,18 @@ class WorkflowCallbacksManager(WorkflowCallbacks):
         """Register a new WorkflowCallbacks type."""
         self._callbacks.append(callbacks)

+    def pipeline_start(self, names: list[str]) -> None:
+        """Execute this callback when the entire pipeline starts."""
+        for callback in self._callbacks:
+            if hasattr(callback, "pipeline_start"):
+                callback.pipeline_start(names)
+
+    def pipeline_end(self, results: list[PipelineRunResult]) -> None:
+        """Execute this callback when the entire pipeline ends."""
+        for callback in self._callbacks:
+            if hasattr(callback, "pipeline_end"):
+                callback.pipeline_end(results)
+
     def workflow_start(self, name: str, instance: object) -> None:
         """Execute this callback when a workflow starts."""
         for callback in self._callbacks:
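The manager fans each new hook out to every registered callback, and the hasattr guard means callback implementations written before pipeline_start/pipeline_end existed can still be registered without raising. A rough standalone sketch of that fan-out pattern (illustrative only, not graphrag code):

```python
class CallbackFanout:
    """Dispatches lifecycle events to registered observers, skipping missing hooks."""

    def __init__(self) -> None:
        self._observers: list[object] = []

    def register(self, observer: object) -> None:
        self._observers.append(observer)

    def pipeline_start(self, names: list[str]) -> None:
        for observer in self._observers:
            # only call the hook on observers that actually define it
            if hasattr(observer, "pipeline_start"):
                observer.pipeline_start(names)


class OldStyleObserver:
    """Predates the new hooks; still safe to register."""


fanout = CallbackFanout()
fanout.register(OldStyleObserver())
fanout.pipeline_start(["create_base_text_units"])  # no AttributeError
```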
@@ -527,7 +527,7 @@ def _resolve_output_files(
         return dataframe_dict
     # Loading output files for single-index search
     dataframe_dict["multi-index"] = False
-    output_config = config.output.model_dump()  # type: ignore
+    output_config = config.output.model_dump()
     storage_obj = StorageFactory().create_storage(
         storage_type=output_config["type"], kwargs=output_config
     )
graphrag/index/run/pipeline.py (new file, 23 lines)
@@ -0,0 +1,23 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing the Pipeline class."""
+
+from collections.abc import Generator
+
+from graphrag.index.typing import Workflow
+
+
+class Pipeline:
+    """Encapsulates running workflows."""
+
+    def __init__(self, workflows: list[Workflow]):
+        self.workflows = workflows
+
+    def run(self) -> Generator[Workflow]:
+        """Return a Generator over the pipeline workflows."""
+        yield from self.workflows
+
+    def names(self) -> list[str]:
+        """Return the names of the workflows in the pipeline."""
+        return [name for name, _ in self.workflows]
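The new Pipeline wrapper keeps the old iteration behaviour (via run()) while also exposing workflow names up front, which is what lets build_index call pipeline_start before anything executes. A small usage sketch with placeholder workflow functions; the real WorkflowFunction signature is richer than shown here:

```python
from graphrag.index.run.pipeline import Pipeline


async def load_documents(config, context):  # placeholder workflow function
    ...


async def build_graph(config, context):  # placeholder workflow function
    ...


pipeline = Pipeline([
    ("load_documents", load_documents),
    ("build_graph", build_graph),
])

print(pipeline.names())  # ['load_documents', 'build_graph'], known before running

for name, workflow_fn in pipeline.run():
    print("would execute:", name)
```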
graphrag/index/run/pipeline_run_result.py (new file, 22 lines)
@@ -0,0 +1,22 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing the PipelineRunResult class."""
+
+from dataclasses import dataclass
+from typing import Any
+
+from graphrag.config.models.graph_rag_config import GraphRagConfig
+
+
+@dataclass
+class PipelineRunResult:
+    """Pipeline run result class definition."""
+
+    workflow: str
+    """The name of the workflow that was executed."""
+    result: Any | None
+    """The result of the workflow function. This can be anything - we use it only for logging downstream, and expect each workflow function to write official outputs to the provided storage."""
+    config: GraphRagConfig | None
+    """Final config after running the workflow, which may have been mutated."""
+    errors: list[BaseException] | None
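Moving PipelineRunResult into its own module lets the callbacks and the runner share it without importing the whole typing module. A short sketch of how a pipeline_end hook might inspect the results it receives; the workflow names are invented for illustration:

```python
from graphrag.index.run.pipeline_run_result import PipelineRunResult


def summarize(results: list[PipelineRunResult]) -> str:
    """Build the kind of one-line summary a pipeline_end hook might log."""
    failed = [r.workflow for r in results if r.errors]
    return f"{len(results)} workflows, {len(failed)} failed: {failed or 'none'}"


example = [
    PipelineRunResult(workflow="create_base_text_units", result=None, config=None, errors=None),
    PipelineRunResult(workflow="extract_graph", result=None, config=None, errors=[RuntimeError("boom")]),
]
print(summarize(example))
```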
@@ -15,20 +15,19 @@ import pandas as pd

 from graphrag.cache.factory import CacheFactory
 from graphrag.cache.pipeline_cache import PipelineCache
-from graphrag.callbacks.console_workflow_callbacks import ConsoleWorkflowCallbacks
 from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.config.models.graph_rag_config import GraphRagConfig
 from graphrag.index.context import PipelineRunStats
 from graphrag.index.input.factory import create_input
-from graphrag.index.run.utils import create_callback_chain, create_run_context
-from graphrag.index.typing import Pipeline, PipelineRunResult
+from graphrag.index.run.pipeline import Pipeline
+from graphrag.index.run.pipeline_run_result import PipelineRunResult
+from graphrag.index.run.utils import create_run_context
 from graphrag.index.update.incremental_index import (
     get_delta_docs,
     update_dataframe_outputs,
 )
 from graphrag.logger.base import ProgressLogger
-from graphrag.logger.null_progress import NullProgressLogger
 from graphrag.logger.progress import Progress
 from graphrag.storage.factory import StorageFactory
 from graphrag.storage.pipeline_storage import PipelineStorage
@@ -40,24 +39,21 @@ log = logging.getLogger(__name__)
 async def run_pipeline(
     pipeline: Pipeline,
     config: GraphRagConfig,
-    cache: PipelineCache | None = None,
-    callbacks: list[WorkflowCallbacks] | None = None,
-    logger: ProgressLogger | None = None,
+    callbacks: WorkflowCallbacks,
+    logger: ProgressLogger,
     is_update_run: bool = False,
 ) -> AsyncIterable[PipelineRunResult]:
     """Run all workflows using a simplified pipeline."""
     root_dir = config.root_dir
-    progress_logger = logger or NullProgressLogger()
-    callbacks = callbacks or [ConsoleWorkflowCallbacks()]
-    callback_chain = create_callback_chain(callbacks, progress_logger)
-    storage_config = config.output.model_dump()  # type: ignore
+
+    storage_config = config.output.model_dump()
     storage = StorageFactory().create_storage(
-        storage_type=storage_config["type"],  # type: ignore
+        storage_type=storage_config["type"],
         kwargs=storage_config,
     )
-    cache_config = config.cache.model_dump()  # type: ignore
-    cache = cache or CacheFactory().create_cache(
-        cache_type=cache_config["type"],  # type: ignore
+    cache_config = config.cache.model_dump()
+    cache = CacheFactory().create_cache(
+        cache_type=cache_config["type"],
         root_dir=root_dir,
         kwargs=cache_config,
     )
@@ -65,18 +61,18 @@ async def run_pipeline(
     dataset = await create_input(config.input, logger, root_dir)

     if is_update_run:
-        progress_logger.info("Running incremental indexing.")
+        logger.info("Running incremental indexing.")

         delta_dataset = await get_delta_docs(dataset, storage)

         # warn on empty delta dataset
         if delta_dataset.new_inputs.empty:
             warning_msg = "Incremental indexing found no new documents, exiting."
-            progress_logger.warning(warning_msg)
+            logger.warning(warning_msg)
         else:
-            update_storage_config = config.update_index_output.model_dump()  # type: ignore
+            update_storage_config = config.update_index_output.model_dump()
             update_storage = StorageFactory().create_storage(
-                storage_type=update_storage_config["type"],  # type: ignore
+                storage_type=update_storage_config["type"],
                 kwargs=update_storage_config,
             )
             # we use this to store the new subset index, and will merge its content with the previous index
@@ -94,12 +90,12 @@ async def run_pipeline(
                 dataset=delta_dataset.new_inputs,
                 cache=cache,
                 storage=delta_storage,
-                callbacks=callback_chain,
-                logger=progress_logger,
+                callbacks=callbacks,
+                logger=logger,
             ):
                 yield table

-            progress_logger.success("Finished running workflows on new documents.")
+            logger.success("Finished running workflows on new documents.")

             await update_dataframe_outputs(
                 previous_storage=previous_storage,
@@ -108,11 +104,11 @@ async def run_pipeline(
                 config=config,
                 cache=cache,
                 callbacks=NoopWorkflowCallbacks(),
-                progress_logger=progress_logger,
+                progress_logger=logger,
             )

     else:
-        progress_logger.info("Running standard indexing.")
+        logger.info("Running standard indexing.")

         async for table in _run_pipeline(
             pipeline=pipeline,
@@ -120,8 +116,8 @@
             dataset=dataset,
             cache=cache,
             storage=storage,
-            callbacks=callback_chain,
-            logger=progress_logger,
+            callbacks=callbacks,
+            logger=logger,
         ):
             yield table
@@ -148,7 +144,7 @@ async def _run_pipeline(
     await _dump_stats(context.stats, context.storage)
     await write_table_to_storage(dataset, "documents", context.storage)

-    for name, workflow_function in pipeline:
+    for name, workflow_function in pipeline.run():
         last_workflow = name
         progress = logger.child(name, transient=False)
         callbacks.workflow_start(name, None)
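With callbacks and logger now required, a direct caller of run_pipeline supplies both explicitly and consumes the results as an async stream. A hedged sketch under the assumption of an already-loaded GraphRagConfig; the IndexingMethod member name is an assumption, the factory and chain helpers are the ones referenced in this diff:

```python
from graphrag.callbacks.console_workflow_callbacks import ConsoleWorkflowCallbacks
from graphrag.config.enums import IndexingMethod
from graphrag.index.run.run_pipeline import run_pipeline
from graphrag.index.run.utils import create_callback_chain
from graphrag.index.workflows.factory import PipelineFactory
from graphrag.logger.null_progress import NullProgressLogger


async def index_with_defaults(config):  # config: a loaded GraphRagConfig
    logger = NullProgressLogger()
    # wrap individual callbacks into a single WorkflowCallbacks chain
    callbacks = create_callback_chain([ConsoleWorkflowCallbacks()], logger)
    pipeline = PipelineFactory.create_pipeline(config, IndexingMethod.Standard)

    callbacks.pipeline_start(pipeline.names())
    results = []
    async for result in run_pipeline(pipeline, config, callbacks=callbacks, logger=logger):
        results.append(result)
    callbacks.pipeline_end(results)
    return results
```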
@@ -3,7 +3,7 @@

 """A module containing the 'PipelineRunResult' model."""

-from collections.abc import Awaitable, Callable, Generator
+from collections.abc import Awaitable, Callable
 from dataclasses import dataclass
 from typing import Any

@@ -29,18 +29,3 @@ WorkflowFunction = Callable[
     Awaitable[WorkflowFunctionOutput],
 ]
 Workflow = tuple[str, WorkflowFunction]
-
-Pipeline = Generator[Workflow]
-
-
-@dataclass
-class PipelineRunResult:
-    """Pipeline run result class definition."""
-
-    workflow: str
-    """The name of the workflow that was executed."""
-    result: Any | None
-    """The result of the workflow function. This can be anything - we use it only for logging downstream, and expect each workflow function to write official outputs to the provided storage."""
-    config: GraphRagConfig | None
-    """Final config after running the workflow, which may have been mutated."""
-    errors: list[BaseException] | None
@@ -7,7 +7,8 @@ from typing import ClassVar

 from graphrag.config.enums import IndexingMethod
 from graphrag.config.models.graph_rag_config import GraphRagConfig
-from graphrag.index.typing import Pipeline, WorkflowFunction
+from graphrag.index.run.pipeline import Pipeline
+from graphrag.index.typing import WorkflowFunction


 class PipelineFactory:
@@ -32,8 +33,7 @@ class PipelineFactory:
     ) -> Pipeline:
         """Create a pipeline generator."""
         workflows = _get_workflows_list(config, method)
-        for name in workflows:
-            yield name, cls.workflows[name]
+        return Pipeline([(name, cls.workflows[name]) for name in workflows])


 def _get_workflows_list(
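Returning a Pipeline object rather than yielding from a generator means the workflow list can be inspected (and reported via pipeline_start) before anything runs, and iterated more than once if needed. A brief sketch, assuming a loaded GraphRagConfig and the same IndexingMethod member name as above:

```python
from graphrag.config.enums import IndexingMethod
from graphrag.index.workflows.factory import PipelineFactory


def describe_pipeline(config) -> list[str]:
    """Return the ordered workflow names without executing anything."""
    pipeline = PipelineFactory.create_pipeline(config, IndexingMethod.Standard)
    return pipeline.names()  # available up front, unlike the old generator
```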
@@ -160,7 +160,7 @@ class TestRunChain(unittest.IsolatedAsyncioTestCase):
             ),
         )

-        graph = results.graph  # type: ignore
+        graph = results.graph
         assert graph is not None, "No graph returned!"

         # TODO: The edges might come back in any order, but we're assuming they're coming
@@ -210,7 +210,7 @@ class TestRunChain(unittest.IsolatedAsyncioTestCase):
             ),
         )

-        graph = results.graph  # type: ignore
+        graph = results.graph
         assert graph is not None, "No graph returned!"
         edges = list(graph.edges(data=True))

@@ -218,6 +218,6 @@ class TestRunChain(unittest.IsolatedAsyncioTestCase):
         assert len(edges) == 2

         # Sort by source_id for consistent ordering
-        edge_source_ids = sorted([edge[2].get("source_id", "") for edge in edges])  # type: ignore
-        assert edge_source_ids[0].split(",") == ["1"]  # type: ignore
-        assert edge_source_ids[1].split(",") == ["2"]  # type: ignore
+        edge_source_ids = sorted([edge[2].get("source_id", "") for edge in edges])
+        assert edge_source_ids[0].split(",") == ["1"]
+        assert edge_source_ids[1].split(",") == ["2"]