Refactor callbacks (#1583)

* Unify Workflow and Verb callbacks interfaces

* Semver

* Fix storage class instantiation (#1582)

---------

Co-authored-by: Josh Bradley <joshbradley@microsoft.com>
This commit is contained in:
Nathan Evans 2025-01-06 10:58:59 -08:00 committed by GitHub
parent cbb8f8788e
commit 7ec9ef0261
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
70 changed files with 193 additions and 367 deletions

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Simplify callbacks model."
}

View File

@ -207,7 +207,7 @@
"outputs": [],
"source": [
"from graphrag.cache.factory import create_cache\n",
"from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks\n",
"from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks\n",
"from graphrag.index.flows.generate_text_embeddings import generate_text_embeddings\n",
"\n",
"# We only need to re-run the embeddings workflow, to ensure that embeddings for all required search fields are in place\n",
@ -219,7 +219,7 @@
"config = workflow.config\n",
"text_embed = config.get(\"text_embed\", {})\n",
"embedded_fields = config.get(\"embedded_fields\", {})\n",
"callbacks = NoopVerbCallbacks()\n",
"callbacks = NoopWorkflowCallbacks()\n",
"cache = create_cache(pipeline_config.cache, PROJECT_DIRECTORY)\n",
"\n",
"await generate_text_embeddings(\n",

View File

@ -13,7 +13,7 @@ Backwards compatibility is not guaranteed at this time.
from pydantic import PositiveInt, validate_call
from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.llm.load_llm import load_llm
from graphrag.logger.print_progress import PrintProgressLogger
@ -99,7 +99,7 @@ async def generate_indexing_prompts(
"prompt_tuning",
config.llm,
cache=None,
callbacks=NoopVerbCallbacks(),
callbacks=NoopWorkflowCallbacks(),
)
if not domain:

View File

@ -84,7 +84,7 @@ class BlobWorkflowCallbacks(NoopWorkflowCallbacks):
# update the blob's block count
self._num_blocks += 1
def on_error(
def error(
self,
message: str,
cause: BaseException | None = None,
@ -100,10 +100,10 @@ class BlobWorkflowCallbacks(NoopWorkflowCallbacks):
"details": details,
})
def on_warning(self, message: str, details: dict | None = None):
def warning(self, message: str, details: dict | None = None):
"""Report a warning."""
self._write_log({"type": "warning", "data": message, "details": details})
def on_log(self, message: str, details: dict | None = None):
def log(self, message: str, details: dict | None = None):
"""Report a generic log message."""
self._write_log({"type": "log", "data": message, "details": details})

View File

@ -9,7 +9,7 @@ from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
class ConsoleWorkflowCallbacks(NoopWorkflowCallbacks):
"""A logger that writes to a console."""
def on_error(
def error(
self,
message: str,
cause: BaseException | None = None,
@ -19,11 +19,11 @@ class ConsoleWorkflowCallbacks(NoopWorkflowCallbacks):
"""Handle when an error occurs."""
print(message, str(cause), stack, details) # noqa T201
def on_warning(self, message: str, details: dict | None = None):
def warning(self, message: str, details: dict | None = None):
"""Handle when a warning occurs."""
_print_warning(message)
def on_log(self, message: str, details: dict | None = None):
def log(self, message: str, details: dict | None = None):
"""Handle when a log message is produced."""
print(message, details) # noqa T201

View File

@ -1,46 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Contains the DelegatingVerbCallback definition."""
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.logger.progress import Progress
class DelegatingVerbCallbacks(VerbCallbacks):
    """Adapter that exposes the VerbCallbacks interface on top of a WorkflowCallbacks object.

    Each verb-level notification is forwarded to the wrapped workflow
    callbacks. Progress updates are forwarded as step-progress events
    labelled with the name given at construction time.
    """

    _workflow_callbacks: WorkflowCallbacks
    _name: str

    def __init__(self, name: str, workflow_callbacks: WorkflowCallbacks):
        """Remember the name and the workflow callbacks that receive forwarded events."""
        self._name = name
        self._workflow_callbacks = workflow_callbacks

    def progress(self, progress: Progress) -> None:
        """Forward a progress update as a step-progress event under the stored name."""
        target = self._workflow_callbacks
        target.on_step_progress(self._name, progress)

    def error(
        self,
        message: str,
        cause: BaseException | None = None,
        stack: str | None = None,
        details: dict | None = None,
    ) -> None:
        """Forward an error report to the wrapped workflow callbacks."""
        target = self._workflow_callbacks
        target.on_error(message, cause, stack, details)

    def warning(self, message: str, details: dict | None = None) -> None:
        """Forward a warning to the wrapped workflow callbacks."""
        target = self._workflow_callbacks
        target.on_warning(message, details)

    def log(self, message: str, details: dict | None = None) -> None:
        """Forward a log message to the wrapped workflow callbacks."""
        target = self._workflow_callbacks
        target.on_log(message, details)

    def measure(self, name: str, value: float, details: dict | None = None) -> None:
        """Forward a measurement to the wrapped workflow callbacks."""
        target = self._workflow_callbacks
        target.on_measure(name, value, details)

View File

@ -25,7 +25,7 @@ class FileWorkflowCallbacks(NoopWorkflowCallbacks):
Path(directory) / "logs.json", "a", encoding="utf-8", errors="strict"
)
def on_error(
def error(
self,
message: str,
cause: BaseException | None = None,
@ -50,7 +50,7 @@ class FileWorkflowCallbacks(NoopWorkflowCallbacks):
message = f"{message} details={details}"
log.info(message)
def on_warning(self, message: str, details: dict | None = None):
def warning(self, message: str, details: dict | None = None):
"""Handle when a warning occurs."""
self._out_stream.write(
json.dumps(
@ -61,7 +61,7 @@ class FileWorkflowCallbacks(NoopWorkflowCallbacks):
)
_print_warning(message)
def on_log(self, message: str, details: dict | None = None):
def log(self, message: str, details: dict | None = None):
"""Handle when a log message is produced."""
self._out_stream.write(
json.dumps(

View File

@ -1,35 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Defines the interface for verb callbacks."""
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.logger.progress import Progress
class NoopVerbCallbacks(VerbCallbacks):
    """VerbCallbacks implementation whose every callback does nothing."""

    def __init__(self) -> None:
        """Create the no-op callbacks; there is no state to set up."""

    def progress(self, progress: Progress) -> None:
        """Ignore a progress update from the verb execution."""

    def error(
        self,
        message: str,
        cause: BaseException | None = None,
        stack: str | None = None,
        details: dict | None = None,
    ) -> None:
        """Ignore an error reported by the verb execution."""

    def warning(self, message: str, details: dict | None = None) -> None:
        """Ignore a warning reported by the verb execution."""

    def log(self, message: str, details: dict | None = None) -> None:
        """Ignore an informational message from the verb execution."""

    def measure(self, name: str, value: float) -> None:
        """Ignore a telemetry measurement from the verb execution."""

View File

@ -3,8 +3,6 @@
"""A no-op implementation of WorkflowCallbacks."""
from typing import Any
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.logger.progress import Progress
@ -12,22 +10,16 @@ from graphrag.logger.progress import Progress
class NoopWorkflowCallbacks(WorkflowCallbacks):
"""A no-op implementation of WorkflowCallbacks."""
def on_workflow_start(self, name: str, instance: object) -> None:
def workflow_start(self, name: str, instance: object) -> None:
"""Execute this callback when a workflow starts."""
def on_workflow_end(self, name: str, instance: object) -> None:
def workflow_end(self, name: str, instance: object) -> None:
"""Execute this callback when a workflow ends."""
def on_step_start(self, step_name: str) -> None:
"""Execute this callback every time a step starts."""
def on_step_end(self, step_name: str, result: Any) -> None:
"""Execute this callback every time a step ends."""
def on_step_progress(self, step_name: str, progress: Progress) -> None:
def progress(self, progress: Progress) -> None:
"""Handle when progress occurs."""
def on_error(
def error(
self,
message: str,
cause: BaseException | None = None,
@ -36,11 +28,8 @@ class NoopWorkflowCallbacks(WorkflowCallbacks):
) -> None:
"""Handle when an error occurs."""
def on_warning(self, message: str, details: dict | None = None) -> None:
def warning(self, message: str, details: dict | None = None) -> None:
"""Handle when a warning occurs."""
def on_log(self, message: str, details: dict | None = None) -> None:
def log(self, message: str, details: dict | None = None) -> None:
"""Handle when a log message occurs."""
def on_measure(self, name: str, value: float, details: dict | None = None) -> None:
"""Handle when a measurement occurs."""

View File

@ -3,8 +3,6 @@
"""A workflow callback manager that emits updates."""
from typing import Any
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.logger.base import ProgressLogger
from graphrag.logger.progress import Progress
@ -31,23 +29,14 @@ class ProgressWorkflowCallbacks(NoopWorkflowCallbacks):
def _latest(self) -> ProgressLogger:
return self._progress_stack[-1]
def on_workflow_start(self, name: str, instance: object) -> None:
def workflow_start(self, name: str, instance: object) -> None:
"""Execute this callback when a workflow starts."""
self._push(name)
def on_workflow_end(self, name: str, instance: object) -> None:
def workflow_end(self, name: str, instance: object) -> None:
"""Execute this callback when a workflow ends."""
self._pop()
def on_step_start(self, step_name: str) -> None:
"""Execute this callback every time a step starts."""
self._push(f"Step {step_name}")
self._latest(Progress(percent=0))
def on_step_end(self, step_name: str, result: Any) -> None:
"""Execute this callback every time a step ends."""
self._pop()
def on_step_progress(self, step_name: str, progress: Progress) -> None:
def progress(self, progress: Progress) -> None:
"""Handle when progress occurs."""
self._latest(progress)

View File

@ -1,38 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Defines the interface for verb callbacks."""
from typing import Protocol
from graphrag.logger.progress import Progress
class VerbCallbacks(Protocol):
    """Structural interface a verb uses to report status updates to the pipeline."""

    def progress(self, progress: Progress) -> None:
        """Report a progress update from the verb execution."""
        ...

    def error(
        self,
        message: str,
        cause: BaseException | None = None,
        stack: str | None = None,
        details: dict | None = None,
    ) -> None:
        """Report an error from the verb execution."""
        ...

    def warning(self, message: str, details: dict | None = None) -> None:
        """Report a warning from the verb execution."""
        ...

    def log(self, message: str, details: dict | None = None) -> None:
        """Report an informational message from the verb execution."""
        ...

    def measure(self, name: str, value: float) -> None:
        """Report a telemetry measurement from the verb execution."""
        ...

View File

@ -3,7 +3,7 @@
"""Collection of callbacks that can be used to monitor the workflow execution."""
from typing import Any, Protocol
from typing import Protocol
from graphrag.logger.progress import Progress
@ -15,27 +15,19 @@ class WorkflowCallbacks(Protocol):
This base class is a "noop" implementation so that clients may implement just the callbacks they need.
"""
def on_workflow_start(self, name: str, instance: object) -> None:
def workflow_start(self, name: str, instance: object) -> None:
"""Execute this callback when a workflow starts."""
...
def on_workflow_end(self, name: str, instance: object) -> None:
def workflow_end(self, name: str, instance: object) -> None:
"""Execute this callback when a workflow ends."""
...
def on_step_start(self, step_name: str) -> None:
"""Execute this callback every time a step starts."""
...
def on_step_end(self, step_name: str, result: Any) -> None:
"""Execute this callback every time a step ends."""
...
def on_step_progress(self, step_name: str, progress: Progress) -> None:
def progress(self, progress: Progress) -> None:
"""Handle when progress occurs."""
...
def on_error(
def error(
self,
message: str,
cause: BaseException | None = None,
@ -45,14 +37,10 @@ class WorkflowCallbacks(Protocol):
"""Handle when an error occurs."""
...
def on_warning(self, message: str, details: dict | None = None) -> None:
def warning(self, message: str, details: dict | None = None) -> None:
"""Handle when a warning occurs."""
...
def on_log(self, message: str, details: dict | None = None) -> None:
def log(self, message: str, details: dict | None = None) -> None:
"""Handle when a log message occurs."""
...
def on_measure(self, name: str, value: float, details: dict | None = None) -> None:
"""Handle when a measurement occurs."""
...

View File

@ -3,8 +3,6 @@
"""A module containing the WorkflowCallbacks registry."""
from typing import Any
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.logger.progress import Progress
@ -22,37 +20,25 @@ class WorkflowCallbacksManager(WorkflowCallbacks):
"""Register a new WorkflowCallbacks type."""
self._callbacks.append(callbacks)
def on_workflow_start(self, name: str, instance: object) -> None:
def workflow_start(self, name: str, instance: object) -> None:
"""Execute this callback when a workflow starts."""
for callback in self._callbacks:
if hasattr(callback, "on_workflow_start"):
callback.on_workflow_start(name, instance)
if hasattr(callback, "workflow_start"):
callback.workflow_start(name, instance)
def on_workflow_end(self, name: str, instance: object) -> None:
def workflow_end(self, name: str, instance: object) -> None:
"""Execute this callback when a workflow ends."""
for callback in self._callbacks:
if hasattr(callback, "on_workflow_end"):
callback.on_workflow_end(name, instance)
if hasattr(callback, "workflow_end"):
callback.workflow_end(name, instance)
def on_step_start(self, step_name: str) -> None:
"""Execute this callback every time a step starts."""
for callback in self._callbacks:
if hasattr(callback, "on_step_start"):
callback.on_step_start(step_name)
def on_step_end(self, step_name: str, result: Any) -> None:
"""Execute this callback every time a step ends."""
for callback in self._callbacks:
if hasattr(callback, "on_step_end"):
callback.on_step_end(step_name, result)
def on_step_progress(self, step_name: str, progress: Progress) -> None:
def progress(self, progress: Progress) -> None:
"""Handle when progress occurs."""
for callback in self._callbacks:
if hasattr(callback, "on_step_progress"):
callback.on_step_progress(step_name, progress)
if hasattr(callback, "progress"):
callback.progress(progress)
def on_error(
def error(
self,
message: str,
cause: BaseException | None = None,
@ -61,23 +47,17 @@ class WorkflowCallbacksManager(WorkflowCallbacks):
) -> None:
"""Handle when an error occurs."""
for callback in self._callbacks:
if hasattr(callback, "on_error"):
callback.on_error(message, cause, stack, details)
if hasattr(callback, "error"):
callback.error(message, cause, stack, details)
def on_warning(self, message: str, details: dict | None = None) -> None:
def warning(self, message: str, details: dict | None = None) -> None:
"""Handle when a warning occurs."""
for callback in self._callbacks:
if hasattr(callback, "on_warning"):
callback.on_warning(message, details)
if hasattr(callback, "warning"):
callback.warning(message, details)
def on_log(self, message: str, details: dict | None = None) -> None:
def log(self, message: str, details: dict | None = None) -> None:
"""Handle when a log message occurs."""
for callback in self._callbacks:
if hasattr(callback, "on_log"):
callback.on_log(message, details)
def on_measure(self, name: str, value: float, details: dict | None = None) -> None:
"""Handle when a measurement occurs."""
for callback in self._callbacks:
if hasattr(callback, "on_measure"):
callback.on_measure(name, value, details)
if hasattr(callback, "log"):
callback.log(message, details)

View File

@ -7,7 +7,7 @@ from typing import cast
import pandas as pd
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.chunking_config import ChunkStrategyType
from graphrag.index.operations.chunk_text.chunk_text import chunk_text
from graphrag.index.utils.hashing import gen_sha512_hash
@ -16,7 +16,7 @@ from graphrag.logger.progress import Progress
def create_base_text_units(
documents: pd.DataFrame,
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
group_by_columns: list[str],
size: int,
overlap: int,

View File

@ -8,7 +8,7 @@ from uuid import uuid4
import pandas as pd
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.enums import AsyncType
from graphrag.index.operations.summarize_communities import (
prepare_community_reports,
@ -43,7 +43,7 @@ async def create_final_community_reports(
entities: pd.DataFrame,
communities: pd.DataFrame,
claims_input: pd.DataFrame | None,
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
cache: PipelineCache,
summarization_strategy: dict,
async_mode: AsyncType = AsyncType.AsyncIO,

View File

@ -9,7 +9,7 @@ from uuid import uuid4
import pandas as pd
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.enums import AsyncType
from graphrag.index.operations.extract_covariates.extract_covariates import (
extract_covariates,
@ -18,7 +18,7 @@ from graphrag.index.operations.extract_covariates.extract_covariates import (
async def create_final_covariates(
text_units: pd.DataFrame,
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
cache: PipelineCache,
covariate_type: str,
extraction_strategy: dict[str, Any] | None,

View File

@ -5,7 +5,7 @@
import pandas as pd
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.embed_graph_config import EmbedGraphConfig
from graphrag.index.operations.compute_degree import compute_degree
from graphrag.index.operations.create_graph import create_graph
@ -17,7 +17,7 @@ def create_final_nodes(
base_entity_nodes: pd.DataFrame,
base_relationship_edges: pd.DataFrame,
base_communities: pd.DataFrame,
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
embed_config: EmbedGraphConfig,
layout_enabled: bool,
) -> pd.DataFrame:

View File

@ -9,7 +9,7 @@ from uuid import uuid4
import pandas as pd
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.enums import AsyncType
from graphrag.index.operations.extract_entities import extract_entities
from graphrag.index.operations.summarize_descriptions import (
@ -19,7 +19,7 @@ from graphrag.index.operations.summarize_descriptions import (
async def extract_graph(
text_units: pd.DataFrame,
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
cache: PipelineCache,
extraction_strategy: dict[str, Any] | None = None,
extraction_num_threads: int = 4,

View File

@ -8,7 +8,7 @@ import logging
import pandas as pd
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.index.config.embeddings import (
community_full_content_embedding,
community_summary_embedding,
@ -32,7 +32,7 @@ async def generate_text_embeddings(
final_text_units: pd.DataFrame | None,
final_entities: pd.DataFrame | None,
final_community_reports: pd.DataFrame | None,
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
cache: PipelineCache,
storage: PipelineStorage,
text_embed_config: dict,
@ -110,7 +110,7 @@ async def _run_and_snapshot_embeddings(
name: str,
data: pd.DataFrame,
embed_column: str,
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
cache: PipelineCache,
storage: PipelineStorage,
text_embed_config: dict,

View File

@ -30,7 +30,7 @@ from .mock_llm import MockChatLLM
if TYPE_CHECKING:
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.index.typing import ErrorHandlerFn
log = logging.getLogger(__name__)
@ -105,7 +105,7 @@ def load_llm(
name: str,
config: LLMParameters,
*,
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
cache: PipelineCache | None,
chat_only=False,
) -> ChatLLM:
@ -135,7 +135,7 @@ def load_llm_embeddings(
name: str,
llm_config: LLMParameters,
*,
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
cache: PipelineCache | None,
chat_only=False,
) -> EmbeddingsLLM:
@ -160,7 +160,7 @@ def load_llm_embeddings(
raise ValueError(msg)
def _create_error_handler(callbacks: VerbCallbacks) -> ErrorHandlerFn:
def _create_error_handler(callbacks: WorkflowCallbacks) -> ErrorHandlerFn:
def on_error(
error: BaseException | None = None,
stack: str | None = None,

View File

@ -7,7 +7,7 @@ from typing import Any, cast
import pandas as pd
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.chunking_config import ChunkingConfig, ChunkStrategyType
from graphrag.index.operations.chunk_text.typing import (
ChunkInput,
@ -23,7 +23,7 @@ def chunk_text(
overlap: int,
encoding_model: str,
strategy: ChunkStrategyType,
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
) -> pd.Series:
"""
Chunk a piece of text into smaller pieces.

View File

@ -11,7 +11,7 @@ import numpy as np
import pandas as pd
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingStrategy
from graphrag.utils.embeddings import create_collection_name
from graphrag.vector_stores.base import BaseVectorStore, VectorStoreDocument
@ -37,7 +37,7 @@ class TextEmbedStrategyType(str, Enum):
async def embed_text(
input: pd.DataFrame,
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
cache: PipelineCache,
embed_column: str,
strategy: dict,
@ -109,7 +109,7 @@ async def embed_text(
async def _text_embed_in_memory(
input: pd.DataFrame,
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
cache: PipelineCache,
embed_column: str,
strategy: dict,
@ -126,7 +126,7 @@ async def _text_embed_in_memory(
async def _text_embed_with_vector_store(
input: pd.DataFrame,
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
cache: PipelineCache,
embed_column: str,
strategy: dict[str, Any],

View File

@ -8,14 +8,14 @@ from collections.abc import Iterable
from typing import Any
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingResult
from graphrag.logger.progress import ProgressTicker, progress_ticker
async def run( # noqa RUF029 async is required for interface
input: list[str],
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
cache: PipelineCache,
_args: dict[str, Any],
) -> TextEmbeddingResult:

View File

@ -13,7 +13,7 @@ from pydantic import TypeAdapter
import graphrag.config.defaults as defs
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.llm_parameters import LLMParameters
from graphrag.index.llm.load_llm import load_llm_embeddings
from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingResult
@ -26,7 +26,7 @@ log = logging.getLogger(__name__)
async def run(
input: list[str],
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
cache: PipelineCache,
args: dict[str, Any],
) -> TextEmbeddingResult:
@ -75,7 +75,7 @@ def _get_splitter(config: LLMParameters, batch_max_tokens: int) -> TokenTextSpli
def _get_llm(
config: LLMParameters,
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
cache: PipelineCache,
) -> EmbeddingsLLM:
return load_llm_embeddings(

View File

@ -7,7 +7,7 @@ from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
@dataclass
@ -20,7 +20,7 @@ class TextEmbeddingResult:
TextEmbeddingStrategy = Callable[
[
list[str],
VerbCallbacks,
WorkflowCallbacks,
PipelineCache,
dict,
],

View File

@ -12,7 +12,7 @@ import pandas as pd
import graphrag.config.defaults as defs
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.enums import AsyncType
from graphrag.index.llm.load_llm import load_llm, read_llm_params
from graphrag.index.operations.extract_covariates.claim_extractor import ClaimExtractor
@ -30,7 +30,7 @@ DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event"]
async def extract_covariates(
input: pd.DataFrame,
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
cache: PipelineCache,
column: str,
covariate_type: str,
@ -78,7 +78,7 @@ async def run_claim_extraction(
input: str | Iterable[str],
entity_types: list[str],
resolved_entities_map: dict[str, str],
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
cache: PipelineCache,
strategy_config: dict[str, Any],
) -> CovariateExtractionResult:

View File

@ -8,7 +8,7 @@ from dataclasses import dataclass
from typing import Any
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
@dataclass
@ -41,7 +41,7 @@ CovariateExtractStrategy = Callable[
Iterable[str],
list[str],
dict[str, str],
VerbCallbacks,
WorkflowCallbacks,
PipelineCache,
dict[str, Any],
],

View File

@ -9,7 +9,7 @@ from typing import Any
import pandas as pd
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.enums import AsyncType
from graphrag.index.bootstrap import bootstrap
from graphrag.index.operations.extract_entities.typing import (
@ -27,7 +27,7 @@ DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event"]
async def extract_entities(
text_units: pd.DataFrame,
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
cache: PipelineCache,
text_column: str,
id_column: str,

View File

@ -8,7 +8,7 @@ from fnllm import ChatLLM
import graphrag.config.defaults as defs
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.index.llm.load_llm import load_llm, read_llm_params
from graphrag.index.operations.extract_entities.graph_extractor import GraphExtractor
from graphrag.index.operations.extract_entities.typing import (
@ -22,7 +22,7 @@ from graphrag.index.operations.extract_entities.typing import (
async def run_graph_intelligence(
docs: list[Document],
entity_types: EntityTypes,
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
cache: PipelineCache,
args: StrategyConfig,
) -> EntityExtractionResult:
@ -36,7 +36,7 @@ async def run_extract_entities(
llm: ChatLLM,
docs: list[Document],
entity_types: EntityTypes,
callbacks: VerbCallbacks | None,
callbacks: WorkflowCallbacks | None,
args: StrategyConfig,
) -> EntityExtractionResult:
"""Run the entity extraction chain."""

View File

@ -8,7 +8,7 @@ import nltk
from nltk.corpus import words
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.index.operations.extract_entities.typing import (
Document,
EntityExtractionResult,
@ -23,7 +23,7 @@ words.ensure_loaded()
async def run( # noqa RUF029 async is required for interface
docs: list[Document],
entity_types: EntityTypes,
callbacks: VerbCallbacks, # noqa ARG001
callbacks: WorkflowCallbacks, # noqa ARG001
cache: PipelineCache, # noqa ARG001
args: StrategyConfig, # noqa ARG001
) -> EntityExtractionResult:

View File

@ -11,7 +11,7 @@ from typing import Any
import networkx as nx
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
ExtractedEntity = dict[str, Any]
ExtractedRelationship = dict[str, Any]
@ -40,7 +40,7 @@ EntityExtractStrategy = Callable[
[
list[Document],
EntityTypes,
VerbCallbacks,
WorkflowCallbacks,
PipelineCache,
StrategyConfig,
],

View File

@ -6,14 +6,14 @@
import networkx as nx
import pandas as pd
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.index.operations.embed_graph.typing import NodeEmbeddings
from graphrag.index.operations.layout_graph.typing import GraphLayout
def layout_graph(
graph: nx.Graph,
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
enabled: bool,
embeddings: NodeEmbeddings | None,
):
@ -58,7 +58,7 @@ def _run_layout(
graph: nx.Graph,
enabled: bool,
embeddings: NodeEmbeddings,
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
) -> GraphLayout:
if enabled:
from graphrag.index.operations.layout_graph.umap import (

View File

@ -8,7 +8,7 @@ import logging
import pandas as pd
import graphrag.index.operations.summarize_communities.community_reports_extractor.schemas as schemas
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.index.operations.summarize_communities.community_reports_extractor.sort_context import (
parallel_sort_context_batch,
)
@ -24,7 +24,7 @@ def prepare_community_reports(
nodes,
edges,
claims,
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
max_tokens: int = 16_000,
):
"""Prep communities for report generation."""

View File

@ -9,7 +9,7 @@ import traceback
from fnllm import ChatLLM
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.index.llm.load_llm import load_llm, read_llm_params
from graphrag.index.operations.summarize_communities.community_reports_extractor.community_reports_extractor import (
CommunityReportsExtractor,
@ -28,7 +28,7 @@ async def run_graph_intelligence(
community: str | int,
input: str,
level: int,
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
cache: PipelineCache,
args: StrategyConfig,
) -> CommunityReport | None:
@ -44,7 +44,7 @@ async def _run_extractor(
input: str,
level: int,
args: StrategyConfig,
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
) -> CommunityReport | None:
# RateLimiter
rate_limiter = RateLimiter(rate=1, per=60)

View File

@ -10,8 +10,8 @@ import pandas as pd
import graphrag.config.defaults as defaults
import graphrag.index.operations.summarize_communities.community_reports_extractor.schemas as schemas
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.enums import AsyncType
from graphrag.index.operations.summarize_communities.community_reports_extractor import (
prep_community_report_context,
@ -34,7 +34,7 @@ async def summarize_communities(
local_contexts,
nodes,
community_hierarchy,
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
cache: PipelineCache,
strategy: dict,
async_mode: AsyncType = AsyncType.AsyncIO,
@ -73,7 +73,7 @@ async def summarize_communities(
local_reports = await derive_from_rows(
level_contexts,
run_generate,
callbacks=NoopVerbCallbacks(),
callbacks=NoopWorkflowCallbacks(),
num_threads=num_threads,
async_type=async_mode,
)
@ -84,7 +84,7 @@ async def summarize_communities(
async def _generate_report(
runner: CommunityReportsStrategy,
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
cache: PipelineCache,
strategy: dict,
community_id: int,

View File

@ -10,7 +10,7 @@ from typing import Any
from typing_extensions import TypedDict
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
ExtractedEntity = dict[str, Any]
StrategyConfig = dict[str, Any]
@ -45,7 +45,7 @@ CommunityReportsStrategy = Callable[
str | int,
str,
int,
VerbCallbacks,
WorkflowCallbacks,
PipelineCache,
StrategyConfig,
],

View File

@ -6,7 +6,7 @@
from fnllm import ChatLLM
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.index.llm.load_llm import load_llm, read_llm_params
from graphrag.index.operations.summarize_descriptions.description_summary_extractor import (
SummarizeExtractor,
@ -20,7 +20,7 @@ from graphrag.index.operations.summarize_descriptions.typing import (
async def run_graph_intelligence(
id: str | tuple[str, str],
descriptions: list[str],
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
cache: PipelineCache,
args: StrategyConfig,
) -> SummarizedDescriptionResult:
@ -36,7 +36,7 @@ async def run_summarize_descriptions(
llm: ChatLLM,
id: str | tuple[str, str],
descriptions: list[str],
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
args: StrategyConfig,
) -> SummarizedDescriptionResult:
"""Run the entity extraction chain."""

View File

@ -10,7 +10,7 @@ from typing import Any
import pandas as pd
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.index.operations.summarize_descriptions.typing import (
SummarizationStrategy,
SummarizeStrategyType,
@ -23,7 +23,7 @@ log = logging.getLogger(__name__)
async def summarize_descriptions(
entities_df: pd.DataFrame,
relationships_df: pd.DataFrame,
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
cache: PipelineCache,
strategy: dict[str, Any] | None = None,
num_threads: int = 4,

View File

@ -9,7 +9,7 @@ from enum import Enum
from typing import Any, NamedTuple
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
StrategyConfig = dict[str, Any]
@ -26,7 +26,7 @@ SummarizationStrategy = Callable[
[
str | tuple[str, str],
list[str],
VerbCallbacks,
WorkflowCallbacks,
PipelineCache,
StrategyConfig,
],

View File

@ -12,7 +12,7 @@ from typing import Any, TypeVar, cast
import pandas as pd
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.enums import AsyncType
from graphrag.logger.progress import progress_ticker
@ -32,7 +32,7 @@ class ParallelizationError(ValueError):
async def derive_from_rows(
input: pd.DataFrame,
transform: Callable[[pd.Series], Awaitable[ItemType]],
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
num_threads: int = 4,
async_type: AsyncType = AsyncType.AsyncIO,
) -> list[ItemType | None]:
@ -57,7 +57,7 @@ async def derive_from_rows(
async def derive_from_rows_asyncio_threads(
input: pd.DataFrame,
transform: Callable[[pd.Series], Awaitable[ItemType]],
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
num_threads: int | None = 4,
) -> list[ItemType | None]:
"""
@ -87,7 +87,7 @@ async def derive_from_rows_asyncio_threads(
async def derive_from_rows_asyncio(
input: pd.DataFrame,
transform: Callable[[pd.Series], Awaitable[ItemType]],
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
num_threads: int = 4,
) -> list[ItemType | None]:
"""
@ -121,7 +121,7 @@ GatherFn = Callable[[ExecuteFn], Awaitable[list[ItemType | None]]]
async def _derive_from_rows_base(
input: pd.DataFrame,
transform: Callable[[pd.Series], Awaitable[ItemType]],
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
gather: GatherFn[ItemType],
) -> list[ItemType | None]:
"""

View File

@ -9,15 +9,13 @@ import time
import traceback
from collections.abc import AsyncIterable
from dataclasses import asdict
from typing import cast
import pandas as pd
from graphrag.cache.factory import CacheFactory
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.console_workflow_callbacks import ConsoleWorkflowCallbacks
from graphrag.callbacks.delegating_verb_callbacks import DelegatingVerbCallbacks
from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.context import PipelineRunStats
@ -119,7 +117,7 @@ async def run_workflows(
update_storage=update_index_storage,
config=config,
cache=cache,
callbacks=NoopVerbCallbacks(),
callbacks=NoopWorkflowCallbacks(),
progress_logger=progress_logger,
)
@ -163,16 +161,15 @@ async def _run_workflows(
last_workflow = workflow
run_workflow = all_workflows[workflow]
progress = logger.child(workflow, transient=False)
callbacks.on_workflow_start(workflow, None)
verb_callbacks = DelegatingVerbCallbacks(workflow, callbacks)
callbacks.workflow_start(workflow, None)
work_time = time.time()
result = await run_workflow(
config,
context,
verb_callbacks,
callbacks,
)
progress(Progress(percent=1))
callbacks.on_workflow_end(workflow, result)
callbacks.workflow_end(workflow, result)
yield PipelineRunResult(workflow, result, None)
context.stats.workflows[workflow] = {"overall": time.time() - work_time}
@ -186,9 +183,7 @@ async def _run_workflows(
except Exception as e:
log.exception("error running workflow %s", last_workflow)
cast("WorkflowCallbacks", callbacks).on_error(
"Error running pipeline!", e, traceback.format_exc()
)
callbacks.error("Error running pipeline!", e, traceback.format_exc())
yield PipelineRunResult(last_workflow, None, [e])

View File

@ -10,7 +10,7 @@ import numpy as np
import pandas as pd
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.operations.summarize_descriptions.graph_intelligence_strategy import (
run_graph_intelligence as run_entity_summarization,
@ -92,7 +92,7 @@ async def _run_entity_summarization(
entities_df: pd.DataFrame,
config: GraphRagConfig,
cache: PipelineCache,
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
) -> pd.DataFrame:
"""Run entity summarization.

View File

@ -9,7 +9,7 @@ import numpy as np
import pandas as pd
from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.config.embeddings import get_embedded_fields, get_embedding_settings
from graphrag.index.flows.generate_text_embeddings import generate_text_embeddings
@ -86,7 +86,7 @@ async def update_dataframe_outputs(
update_storage: PipelineStorage,
config: GraphRagConfig,
cache: PipelineCache,
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
progress_logger: ProgressLogger,
) -> None:
"""Update the mergeable outputs.

View File

@ -6,7 +6,7 @@
import asyncio
import sys
from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.llm.load_llm import load_llm, load_llm_embeddings
from graphrag.logger.print_progress import ProgressLogger
@ -18,7 +18,7 @@ def validate_config_names(logger: ProgressLogger, parameters: GraphRagConfig) ->
llm = load_llm(
"test-llm",
parameters.llm,
callbacks=NoopVerbCallbacks(),
callbacks=NoopWorkflowCallbacks(),
cache=None,
)
try:
@ -32,7 +32,7 @@ def validate_config_names(logger: ProgressLogger, parameters: GraphRagConfig) ->
embed_llm = load_llm_embeddings(
"test-embed-llm",
parameters.embeddings.llm,
callbacks=NoopVerbCallbacks(),
callbacks=NoopWorkflowCallbacks(),
cache=None,
)
try:

View File

@ -8,7 +8,7 @@ from collections.abc import Awaitable, Callable
import pandas as pd
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.context import PipelineRunContext
@ -88,7 +88,7 @@ from .generate_text_embeddings import (
all_workflows: dict[
str,
Callable[
[GraphRagConfig, PipelineRunContext, VerbCallbacks],
[GraphRagConfig, PipelineRunContext, WorkflowCallbacks],
Awaitable[pd.DataFrame | None],
],
] = {

View File

@ -5,7 +5,7 @@
import pandas as pd
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.context import PipelineRunContext
from graphrag.index.flows.compute_communities import compute_communities
@ -17,7 +17,7 @@ workflow_name = "compute_communities"
async def run_workflow(
config: GraphRagConfig,
context: PipelineRunContext,
_callbacks: VerbCallbacks,
_callbacks: WorkflowCallbacks,
) -> pd.DataFrame | None:
"""All the steps to create the base communities."""
base_relationship_edges = await load_table_from_storage(

View File

@ -5,7 +5,7 @@
import pandas as pd
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.context import PipelineRunContext
from graphrag.index.flows.create_base_text_units import (
@ -19,7 +19,7 @@ workflow_name = "create_base_text_units"
async def run_workflow(
config: GraphRagConfig,
context: PipelineRunContext,
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
) -> pd.DataFrame | None:
"""All the steps to transform base text_units."""
documents = await load_table_from_storage("input", context.storage)

View File

@ -5,7 +5,7 @@
import pandas as pd
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.context import PipelineRunContext
from graphrag.index.flows.create_final_communities import (
@ -19,7 +19,7 @@ workflow_name = "create_final_communities"
async def run_workflow(
_config: GraphRagConfig,
context: PipelineRunContext,
_callbacks: VerbCallbacks,
_callbacks: WorkflowCallbacks,
) -> pd.DataFrame | None:
"""All the steps to transform final communities."""
base_entity_nodes = await load_table_from_storage(

View File

@ -5,7 +5,7 @@
import pandas as pd
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.context import PipelineRunContext
from graphrag.index.flows.create_final_community_reports import (
@ -19,7 +19,7 @@ workflow_name = "create_final_community_reports"
async def run_workflow(
config: GraphRagConfig,
context: PipelineRunContext,
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
) -> pd.DataFrame | None:
"""All the steps to transform community reports."""
nodes = await load_table_from_storage("create_final_nodes", context.storage)

View File

@ -5,7 +5,7 @@
import pandas as pd
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.context import PipelineRunContext
from graphrag.index.flows.create_final_covariates import (
@ -19,7 +19,7 @@ workflow_name = "create_final_covariates"
async def run_workflow(
config: GraphRagConfig,
context: PipelineRunContext,
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
) -> pd.DataFrame | None:
"""All the steps to extract and format covariates."""
text_units = await load_table_from_storage(

View File

@ -5,7 +5,7 @@
import pandas as pd
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.context import PipelineRunContext
from graphrag.index.flows.create_final_documents import (
@ -19,7 +19,7 @@ workflow_name = "create_final_documents"
async def run_workflow(
config: GraphRagConfig,
context: PipelineRunContext,
_callbacks: VerbCallbacks,
_callbacks: WorkflowCallbacks,
) -> pd.DataFrame | None:
"""All the steps to transform final documents."""
documents = await load_table_from_storage("input", context.storage)

View File

@ -5,7 +5,7 @@
import pandas as pd
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.context import PipelineRunContext
from graphrag.index.flows.create_final_entities import (
@ -19,7 +19,7 @@ workflow_name = "create_final_entities"
async def run_workflow(
_config: GraphRagConfig,
context: PipelineRunContext,
_callbacks: VerbCallbacks,
_callbacks: WorkflowCallbacks,
) -> pd.DataFrame | None:
"""All the steps to transform final entities."""
base_entity_nodes = await load_table_from_storage(

View File

@ -5,7 +5,7 @@
import pandas as pd
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.context import PipelineRunContext
from graphrag.index.flows.create_final_nodes import (
@ -19,7 +19,7 @@ workflow_name = "create_final_nodes"
async def run_workflow(
config: GraphRagConfig,
context: PipelineRunContext,
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
) -> pd.DataFrame | None:
"""All the steps to transform final nodes."""
base_entity_nodes = await load_table_from_storage(

View File

@ -5,7 +5,7 @@
import pandas as pd
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.context import PipelineRunContext
from graphrag.index.flows.create_final_relationships import (
@ -19,7 +19,7 @@ workflow_name = "create_final_relationships"
async def run_workflow(
_config: GraphRagConfig,
context: PipelineRunContext,
_callbacks: VerbCallbacks,
_callbacks: WorkflowCallbacks,
) -> pd.DataFrame | None:
"""All the steps to transform final relationships."""
base_relationship_edges = await load_table_from_storage(

View File

@ -5,7 +5,7 @@
import pandas as pd
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.context import PipelineRunContext
from graphrag.index.flows.create_final_text_units import (
@ -19,7 +19,7 @@ workflow_name = "create_final_text_units"
async def run_workflow(
config: GraphRagConfig,
context: PipelineRunContext,
_callbacks: VerbCallbacks,
_callbacks: WorkflowCallbacks,
) -> pd.DataFrame | None:
"""All the steps to transform the text units."""
text_units = await load_table_from_storage(

View File

@ -5,7 +5,7 @@
import pandas as pd
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.context import PipelineRunContext
from graphrag.index.flows.extract_graph import (
@ -21,7 +21,7 @@ workflow_name = "extract_graph"
async def run_workflow(
config: GraphRagConfig,
context: PipelineRunContext,
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
) -> pd.DataFrame | None:
"""All the steps to create the base entity graph."""
text_units = await load_table_from_storage(

View File

@ -5,7 +5,7 @@
import pandas as pd
from graphrag.callbacks.verb_callbacks import VerbCallbacks
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.config.embeddings import get_embedded_fields, get_embedding_settings
from graphrag.index.context import PipelineRunContext
@ -20,7 +20,7 @@ workflow_name = "generate_text_embeddings"
async def run_workflow(
config: GraphRagConfig,
context: PipelineRunContext,
callbacks: VerbCallbacks,
callbacks: WorkflowCallbacks,
) -> pd.DataFrame | None:
"""All the steps to transform community reports."""
final_documents = await load_table_from_storage(

View File

@ -9,7 +9,7 @@ from fnllm import ChatLLM
from pydantic import TypeAdapter
import graphrag.config.defaults as defs
from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.config.models.llm_parameters import LLMParameters
from graphrag.index.input.factory import create_input
@ -77,7 +77,7 @@ async def load_docs_in_chunks(
overlap=MIN_CHUNK_OVERLAP,
encoding_model=defs.ENCODING_MODEL,
strategy=chunk_config.strategy,
callbacks=NoopVerbCallbacks(),
callbacks=NoopWorkflowCallbacks(),
)
# Select chunks into a new df and explode it
@ -98,7 +98,7 @@ async def load_docs_in_chunks(
embedding_llm = load_llm_embeddings(
"prompt_tuning_embeddings",
llm_config,
callbacks=NoopVerbCallbacks(),
callbacks=NoopWorkflowCallbacks(),
cache=None,
)

View File

@ -1,7 +1,7 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.index.workflows.compute_communities import run_workflow
from graphrag.utils.storage import load_table_from_storage
@ -25,7 +25,7 @@ async def test_compute_communities():
await run_workflow(
config,
context,
NoopVerbCallbacks(),
NoopWorkflowCallbacks(),
)
actual = await load_table_from_storage("base_communities", context.storage)

View File

@ -1,7 +1,7 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.index.workflows.create_base_text_units import run_workflow, workflow_name
from graphrag.utils.storage import load_table_from_storage
@ -25,7 +25,7 @@ async def test_create_base_text_units():
await run_workflow(
config,
context,
NoopVerbCallbacks(),
NoopWorkflowCallbacks(),
)
actual = await load_table_from_storage(workflow_name, context.storage)

View File

@ -1,7 +1,7 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.index.workflows.create_final_communities import (
run_workflow,
@ -32,7 +32,7 @@ async def test_create_final_communities():
await run_workflow(
config,
context,
NoopVerbCallbacks(),
NoopWorkflowCallbacks(),
)
actual = await load_table_from_storage(workflow_name, context.storage)

View File

@ -4,7 +4,7 @@
import pytest
from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.config.enums import LLMType
from graphrag.index.operations.summarize_communities.community_reports_extractor.community_reports_extractor import (
@ -70,7 +70,7 @@ async def test_create_final_community_reports():
await run_workflow(
config,
context,
NoopVerbCallbacks(),
NoopWorkflowCallbacks(),
)
actual = await load_table_from_storage(workflow_name, context.storage)
@ -105,5 +105,5 @@ async def test_create_final_community_reports_missing_llm_throws():
await run_workflow(
config,
context,
NoopVerbCallbacks(),
NoopWorkflowCallbacks(),
)

View File

@ -4,7 +4,7 @@
import pytest
from pandas.testing import assert_series_equal
from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.config.enums import LLMType
from graphrag.index.run.derive_from_rows import ParallelizationError
@ -46,7 +46,7 @@ async def test_create_final_covariates():
await run_workflow(
config,
context,
NoopVerbCallbacks(),
NoopWorkflowCallbacks(),
)
actual = await load_table_from_storage(workflow_name, context.storage)
@ -95,5 +95,5 @@ async def test_create_final_covariates_missing_llm_throws():
await run_workflow(
config,
context,
NoopVerbCallbacks(),
NoopWorkflowCallbacks(),
)

View File

@ -1,7 +1,7 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.index.workflows.create_final_documents import (
run_workflow,
@ -28,7 +28,7 @@ async def test_create_final_documents():
await run_workflow(
config,
context,
NoopVerbCallbacks(),
NoopWorkflowCallbacks(),
)
actual = await load_table_from_storage(workflow_name, context.storage)
@ -49,7 +49,7 @@ async def test_create_final_documents_with_attribute_columns():
await run_workflow(
config,
context,
NoopVerbCallbacks(),
NoopWorkflowCallbacks(),
)
actual = await load_table_from_storage(workflow_name, context.storage)

View File

@ -1,7 +1,7 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.index.workflows.create_final_entities import (
run_workflow,
@ -28,7 +28,7 @@ async def test_create_final_entities():
await run_workflow(
config,
context,
NoopVerbCallbacks(),
NoopWorkflowCallbacks(),
)
actual = await load_table_from_storage(workflow_name, context.storage)

View File

@ -1,7 +1,7 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.index.workflows.create_final_nodes import (
run_workflow,
@ -32,7 +32,7 @@ async def test_create_final_nodes():
await run_workflow(
config,
context,
NoopVerbCallbacks(),
NoopWorkflowCallbacks(),
)
actual = await load_table_from_storage(workflow_name, context.storage)

View File

@ -2,7 +2,7 @@
# Licensed under the MIT License
from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.index.workflows.create_final_relationships import (
run_workflow,
@ -29,7 +29,7 @@ async def test_create_final_relationships():
await run_workflow(
config,
context,
NoopVerbCallbacks(),
NoopWorkflowCallbacks(),
)
actual = await load_table_from_storage(workflow_name, context.storage)

View File

@ -1,7 +1,7 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.index.workflows.create_final_text_units import (
run_workflow,
@ -34,7 +34,7 @@ async def test_create_final_text_units():
await run_workflow(
config,
context,
NoopVerbCallbacks(),
NoopWorkflowCallbacks(),
)
actual = await load_table_from_storage(workflow_name, context.storage)
@ -60,7 +60,7 @@ async def test_create_final_text_units_no_covariates():
await run_workflow(
config,
context,
NoopVerbCallbacks(),
NoopWorkflowCallbacks(),
)
actual = await load_table_from_storage(workflow_name, context.storage)

View File

@ -3,7 +3,7 @@
import pytest
from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.config.enums import LLMType
from graphrag.index.workflows.extract_graph import (
@ -68,7 +68,7 @@ async def test_extract_graph():
await run_workflow(
config,
context,
NoopVerbCallbacks(),
NoopWorkflowCallbacks(),
)
# graph construction creates transient tables for nodes, edges, and communities
@ -110,5 +110,5 @@ async def test_extract_graph_missing_llm_throws():
await run_workflow(
config,
context,
NoopVerbCallbacks(),
NoopWorkflowCallbacks(),
)

View File

@ -1,7 +1,7 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
from graphrag.config.create_graphrag_config import create_graphrag_config
from graphrag.config.enums import TextEmbeddingTarget
from graphrag.index.config.embeddings import (
@ -38,7 +38,7 @@ async def test_generate_text_embeddings():
await run_workflow(
config,
context,
NoopVerbCallbacks(),
NoopWorkflowCallbacks(),
)
parquet_files = context.storage.keys()