mirror of
https://github.com/microsoft/graphrag.git
synced 2025-12-26 14:38:52 +00:00
Refactor callbacks (#1583)
* Unify Workflow and Verb callbacks interfaces * Semver * Fix storage class instantiation (#1582) --------- Co-authored-by: Josh Bradley <joshbradley@microsoft.com>
This commit is contained in:
parent
cbb8f8788e
commit
7ec9ef0261
@ -0,0 +1,4 @@
|
||||
{
|
||||
"type": "patch",
|
||||
"description": "Simplify callbacks model."
|
||||
}
|
||||
@ -207,7 +207,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from graphrag.cache.factory import create_cache\n",
|
||||
"from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks\n",
|
||||
"from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks\n",
|
||||
"from graphrag.index.flows.generate_text_embeddings import generate_text_embeddings\n",
|
||||
"\n",
|
||||
"# We only need to re-run the embeddings workflow, to ensure that embeddings for all required search fields are in place\n",
|
||||
@ -219,7 +219,7 @@
|
||||
"config = workflow.config\n",
|
||||
"text_embed = config.get(\"text_embed\", {})\n",
|
||||
"embedded_fields = config.get(\"embedded_fields\", {})\n",
|
||||
"callbacks = NoopVerbCallbacks()\n",
|
||||
"callbacks = NoopWorkflowCallbacks()\n",
|
||||
"cache = create_cache(pipeline_config.cache, PROJECT_DIRECTORY)\n",
|
||||
"\n",
|
||||
"await generate_text_embeddings(\n",
|
||||
|
||||
@ -13,7 +13,7 @@ Backwards compatibility is not guaranteed at this time.
|
||||
|
||||
from pydantic import PositiveInt, validate_call
|
||||
|
||||
from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks
|
||||
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.index.llm.load_llm import load_llm
|
||||
from graphrag.logger.print_progress import PrintProgressLogger
|
||||
@ -99,7 +99,7 @@ async def generate_indexing_prompts(
|
||||
"prompt_tuning",
|
||||
config.llm,
|
||||
cache=None,
|
||||
callbacks=NoopVerbCallbacks(),
|
||||
callbacks=NoopWorkflowCallbacks(),
|
||||
)
|
||||
|
||||
if not domain:
|
||||
|
||||
@ -84,7 +84,7 @@ class BlobWorkflowCallbacks(NoopWorkflowCallbacks):
|
||||
# update the blob's block count
|
||||
self._num_blocks += 1
|
||||
|
||||
def on_error(
|
||||
def error(
|
||||
self,
|
||||
message: str,
|
||||
cause: BaseException | None = None,
|
||||
@ -100,10 +100,10 @@ class BlobWorkflowCallbacks(NoopWorkflowCallbacks):
|
||||
"details": details,
|
||||
})
|
||||
|
||||
def on_warning(self, message: str, details: dict | None = None):
|
||||
def warning(self, message: str, details: dict | None = None):
|
||||
"""Report a warning."""
|
||||
self._write_log({"type": "warning", "data": message, "details": details})
|
||||
|
||||
def on_log(self, message: str, details: dict | None = None):
|
||||
def log(self, message: str, details: dict | None = None):
|
||||
"""Report a generic log message."""
|
||||
self._write_log({"type": "log", "data": message, "details": details})
|
||||
|
||||
@ -9,7 +9,7 @@ from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
|
||||
class ConsoleWorkflowCallbacks(NoopWorkflowCallbacks):
|
||||
"""A logger that writes to a console."""
|
||||
|
||||
def on_error(
|
||||
def error(
|
||||
self,
|
||||
message: str,
|
||||
cause: BaseException | None = None,
|
||||
@ -19,11 +19,11 @@ class ConsoleWorkflowCallbacks(NoopWorkflowCallbacks):
|
||||
"""Handle when an error occurs."""
|
||||
print(message, str(cause), stack, details) # noqa T201
|
||||
|
||||
def on_warning(self, message: str, details: dict | None = None):
|
||||
def warning(self, message: str, details: dict | None = None):
|
||||
"""Handle when a warning occurs."""
|
||||
_print_warning(message)
|
||||
|
||||
def on_log(self, message: str, details: dict | None = None):
|
||||
def log(self, message: str, details: dict | None = None):
|
||||
"""Handle when a log message is produced."""
|
||||
print(message, details) # noqa T201
|
||||
|
||||
|
||||
@ -1,46 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Contains the DelegatingVerbCallback definition."""
|
||||
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.logger.progress import Progress
|
||||
|
||||
|
||||
class DelegatingVerbCallbacks(VerbCallbacks):
    """A VerbCallbacks wrapper that delegates to the underlying WorkflowCallbacks.

    Each verb-level notification is forwarded to the matching ``on_*`` method
    of the wrapped workflow callbacks. Progress updates are forwarded together
    with the name supplied at construction time.
    """

    # Target that receives every forwarded notification.
    _workflow_callbacks: WorkflowCallbacks
    # Name passed as the step name when forwarding progress updates.
    _name: str

    def __init__(self, name: str, workflow_callbacks: WorkflowCallbacks):
        """Create a new instance of DelegatingVerbCallbacks.

        Args:
            name: Name sent along with each forwarded progress update.
            workflow_callbacks: Object that receives the forwarded calls.
        """
        self._workflow_callbacks = workflow_callbacks
        self._name = name

    def progress(self, progress: Progress) -> None:
        """Forward a progress update via ``on_step_progress`` under this name."""
        self._workflow_callbacks.on_step_progress(self._name, progress)

    def error(
        self,
        message: str,
        cause: BaseException | None = None,
        stack: str | None = None,
        details: dict | None = None,
    ) -> None:
        """Forward an error report via ``on_error``, passing all arguments through."""
        self._workflow_callbacks.on_error(message, cause, stack, details)

    def warning(self, message: str, details: dict | None = None) -> None:
        """Forward a warning via ``on_warning``."""
        self._workflow_callbacks.on_warning(message, details)

    def log(self, message: str, details: dict | None = None) -> None:
        """Forward a log message via ``on_log``."""
        self._workflow_callbacks.on_log(message, details)

    def measure(self, name: str, value: float, details: dict | None = None) -> None:
        """Forward a measurement via ``on_measure``."""
        self._workflow_callbacks.on_measure(name, value, details)
|
||||
@ -25,7 +25,7 @@ class FileWorkflowCallbacks(NoopWorkflowCallbacks):
|
||||
Path(directory) / "logs.json", "a", encoding="utf-8", errors="strict"
|
||||
)
|
||||
|
||||
def on_error(
|
||||
def error(
|
||||
self,
|
||||
message: str,
|
||||
cause: BaseException | None = None,
|
||||
@ -50,7 +50,7 @@ class FileWorkflowCallbacks(NoopWorkflowCallbacks):
|
||||
message = f"{message} details={details}"
|
||||
log.info(message)
|
||||
|
||||
def on_warning(self, message: str, details: dict | None = None):
|
||||
def warning(self, message: str, details: dict | None = None):
|
||||
"""Handle when a warning occurs."""
|
||||
self._out_stream.write(
|
||||
json.dumps(
|
||||
@ -61,7 +61,7 @@ class FileWorkflowCallbacks(NoopWorkflowCallbacks):
|
||||
)
|
||||
_print_warning(message)
|
||||
|
||||
def on_log(self, message: str, details: dict | None = None):
|
||||
def log(self, message: str, details: dict | None = None):
|
||||
"""Handle when a log message is produced."""
|
||||
self._out_stream.write(
|
||||
json.dumps(
|
||||
|
||||
@ -1,35 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Defines the interface for verb callbacks."""
|
||||
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.logger.progress import Progress
|
||||
|
||||
|
||||
class NoopVerbCallbacks(VerbCallbacks):
    """A noop implementation of the verb callbacks.

    Every method accepts its arguments and does nothing, so an instance can be
    passed wherever callbacks are required but no reporting is wanted.
    """

    def __init__(self) -> None:
        """Create the callbacks object; there is nothing to initialise."""

    def progress(self, progress: Progress) -> None:
        """Report a progress update from the verb execution."""

    def error(
        self,
        message: str,
        cause: BaseException | None = None,
        stack: str | None = None,
        details: dict | None = None,
    ) -> None:
        """Report an error from the verb execution."""

    def warning(self, message: str, details: dict | None = None) -> None:
        """Report a warning from the verb execution."""

    def log(self, message: str, details: dict | None = None) -> None:
        """Report an informational message from the verb execution."""

    # `details` is optional so existing two-argument callers keep working while
    # the signature lines up with DelegatingVerbCallbacks.measure.
    def measure(self, name: str, value: float, details: dict | None = None) -> None:
        """Report a telemetry measurement from the verb execution."""
|
||||
@ -3,8 +3,6 @@
|
||||
|
||||
"""A no-op implementation of WorkflowCallbacks."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.logger.progress import Progress
|
||||
|
||||
@ -12,22 +10,16 @@ from graphrag.logger.progress import Progress
|
||||
class NoopWorkflowCallbacks(WorkflowCallbacks):
|
||||
"""A no-op implementation of WorkflowCallbacks."""
|
||||
|
||||
def on_workflow_start(self, name: str, instance: object) -> None:
|
||||
def workflow_start(self, name: str, instance: object) -> None:
|
||||
"""Execute this callback when a workflow starts."""
|
||||
|
||||
def on_workflow_end(self, name: str, instance: object) -> None:
|
||||
def workflow_end(self, name: str, instance: object) -> None:
|
||||
"""Execute this callback when a workflow ends."""
|
||||
|
||||
def on_step_start(self, step_name: str) -> None:
|
||||
"""Execute this callback every time a step starts."""
|
||||
|
||||
def on_step_end(self, step_name: str, result: Any) -> None:
|
||||
"""Execute this callback every time a step ends."""
|
||||
|
||||
def on_step_progress(self, step_name: str, progress: Progress) -> None:
|
||||
def progress(self, progress: Progress) -> None:
|
||||
"""Handle when progress occurs."""
|
||||
|
||||
def on_error(
|
||||
def error(
|
||||
self,
|
||||
message: str,
|
||||
cause: BaseException | None = None,
|
||||
@ -36,11 +28,8 @@ class NoopWorkflowCallbacks(WorkflowCallbacks):
|
||||
) -> None:
|
||||
"""Handle when an error occurs."""
|
||||
|
||||
def on_warning(self, message: str, details: dict | None = None) -> None:
|
||||
def warning(self, message: str, details: dict | None = None) -> None:
|
||||
"""Handle when a warning occurs."""
|
||||
|
||||
def on_log(self, message: str, details: dict | None = None) -> None:
|
||||
def log(self, message: str, details: dict | None = None) -> None:
|
||||
"""Handle when a log message occurs."""
|
||||
|
||||
def on_measure(self, name: str, value: float, details: dict | None = None) -> None:
|
||||
"""Handle when a measurement occurs."""
|
||||
|
||||
@ -3,8 +3,6 @@
|
||||
|
||||
"""A workflow callback manager that emits updates."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
|
||||
from graphrag.logger.base import ProgressLogger
|
||||
from graphrag.logger.progress import Progress
|
||||
@ -31,23 +29,14 @@ class ProgressWorkflowCallbacks(NoopWorkflowCallbacks):
|
||||
def _latest(self) -> ProgressLogger:
|
||||
return self._progress_stack[-1]
|
||||
|
||||
def on_workflow_start(self, name: str, instance: object) -> None:
|
||||
def workflow_start(self, name: str, instance: object) -> None:
|
||||
"""Execute this callback when a workflow starts."""
|
||||
self._push(name)
|
||||
|
||||
def on_workflow_end(self, name: str, instance: object) -> None:
|
||||
def workflow_end(self, name: str, instance: object) -> None:
|
||||
"""Execute this callback when a workflow ends."""
|
||||
self._pop()
|
||||
|
||||
def on_step_start(self, step_name: str) -> None:
|
||||
"""Execute this callback every time a step starts."""
|
||||
self._push(f"Step {step_name}")
|
||||
self._latest(Progress(percent=0))
|
||||
|
||||
def on_step_end(self, step_name: str, result: Any) -> None:
|
||||
"""Execute this callback every time a step ends."""
|
||||
self._pop()
|
||||
|
||||
def on_step_progress(self, step_name: str, progress: Progress) -> None:
|
||||
def progress(self, progress: Progress) -> None:
|
||||
"""Handle when progress occurs."""
|
||||
self._latest(progress)
|
||||
|
||||
@ -1,38 +0,0 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
"""Defines the interface for verb callbacks."""
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
from graphrag.logger.progress import Progress
|
||||
|
||||
|
||||
class VerbCallbacks(Protocol):
    """Provides a way to report status updates from the pipeline.

    This is a structural interface: any object exposing these five methods
    with compatible signatures can be used where VerbCallbacks is expected.
    """

    def progress(self, progress: Progress) -> None:
        """Report a progress update from the verb execution."""
        ...

    def error(
        self,
        message: str,
        cause: BaseException | None = None,
        stack: str | None = None,
        details: dict | None = None,
    ) -> None:
        """Report an error from the verb execution."""
        ...

    def warning(self, message: str, details: dict | None = None) -> None:
        """Report a warning from the verb execution."""
        ...

    def log(self, message: str, details: dict | None = None) -> None:
        """Report an informational message from the verb execution."""
        ...

    def measure(self, name: str, value: float) -> None:
        """Report a telemetry measurement from the verb execution."""
        ...
|
||||
@ -3,7 +3,7 @@
|
||||
|
||||
"""Collection of callbacks that can be used to monitor the workflow execution."""
|
||||
|
||||
from typing import Any, Protocol
|
||||
from typing import Protocol
|
||||
|
||||
from graphrag.logger.progress import Progress
|
||||
|
||||
@ -15,27 +15,19 @@ class WorkflowCallbacks(Protocol):
|
||||
This base class is a "noop" implementation so that clients may implement just the callbacks they need.
|
||||
"""
|
||||
|
||||
def on_workflow_start(self, name: str, instance: object) -> None:
|
||||
def workflow_start(self, name: str, instance: object) -> None:
|
||||
"""Execute this callback when a workflow starts."""
|
||||
...
|
||||
|
||||
def on_workflow_end(self, name: str, instance: object) -> None:
|
||||
def workflow_end(self, name: str, instance: object) -> None:
|
||||
"""Execute this callback when a workflow ends."""
|
||||
...
|
||||
|
||||
def on_step_start(self, step_name: str) -> None:
|
||||
"""Execute this callback every time a step starts."""
|
||||
...
|
||||
|
||||
def on_step_end(self, step_name: str, result: Any) -> None:
|
||||
"""Execute this callback every time a step ends."""
|
||||
...
|
||||
|
||||
def on_step_progress(self, step_name: str, progress: Progress) -> None:
|
||||
def progress(self, progress: Progress) -> None:
|
||||
"""Handle when progress occurs."""
|
||||
...
|
||||
|
||||
def on_error(
|
||||
def error(
|
||||
self,
|
||||
message: str,
|
||||
cause: BaseException | None = None,
|
||||
@ -45,14 +37,10 @@ class WorkflowCallbacks(Protocol):
|
||||
"""Handle when an error occurs."""
|
||||
...
|
||||
|
||||
def on_warning(self, message: str, details: dict | None = None) -> None:
|
||||
def warning(self, message: str, details: dict | None = None) -> None:
|
||||
"""Handle when a warning occurs."""
|
||||
...
|
||||
|
||||
def on_log(self, message: str, details: dict | None = None) -> None:
|
||||
def log(self, message: str, details: dict | None = None) -> None:
|
||||
"""Handle when a log message occurs."""
|
||||
...
|
||||
|
||||
def on_measure(self, name: str, value: float, details: dict | None = None) -> None:
|
||||
"""Handle when a measurement occurs."""
|
||||
...
|
||||
|
||||
@ -3,8 +3,6 @@
|
||||
|
||||
"""A module containing the WorkflowCallbacks registry."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.logger.progress import Progress
|
||||
|
||||
@ -22,37 +20,25 @@ class WorkflowCallbacksManager(WorkflowCallbacks):
|
||||
"""Register a new WorkflowCallbacks type."""
|
||||
self._callbacks.append(callbacks)
|
||||
|
||||
def on_workflow_start(self, name: str, instance: object) -> None:
|
||||
def workflow_start(self, name: str, instance: object) -> None:
|
||||
"""Execute this callback when a workflow starts."""
|
||||
for callback in self._callbacks:
|
||||
if hasattr(callback, "on_workflow_start"):
|
||||
callback.on_workflow_start(name, instance)
|
||||
if hasattr(callback, "workflow_start"):
|
||||
callback.workflow_start(name, instance)
|
||||
|
||||
def on_workflow_end(self, name: str, instance: object) -> None:
|
||||
def workflow_end(self, name: str, instance: object) -> None:
|
||||
"""Execute this callback when a workflow ends."""
|
||||
for callback in self._callbacks:
|
||||
if hasattr(callback, "on_workflow_end"):
|
||||
callback.on_workflow_end(name, instance)
|
||||
if hasattr(callback, "workflow_end"):
|
||||
callback.workflow_end(name, instance)
|
||||
|
||||
def on_step_start(self, step_name: str) -> None:
|
||||
"""Execute this callback every time a step starts."""
|
||||
for callback in self._callbacks:
|
||||
if hasattr(callback, "on_step_start"):
|
||||
callback.on_step_start(step_name)
|
||||
|
||||
def on_step_end(self, step_name: str, result: Any) -> None:
|
||||
"""Execute this callback every time a step ends."""
|
||||
for callback in self._callbacks:
|
||||
if hasattr(callback, "on_step_end"):
|
||||
callback.on_step_end(step_name, result)
|
||||
|
||||
def on_step_progress(self, step_name: str, progress: Progress) -> None:
|
||||
def progress(self, progress: Progress) -> None:
|
||||
"""Handle when progress occurs."""
|
||||
for callback in self._callbacks:
|
||||
if hasattr(callback, "on_step_progress"):
|
||||
callback.on_step_progress(step_name, progress)
|
||||
if hasattr(callback, "progress"):
|
||||
callback.progress(progress)
|
||||
|
||||
def on_error(
|
||||
def error(
|
||||
self,
|
||||
message: str,
|
||||
cause: BaseException | None = None,
|
||||
@ -61,23 +47,17 @@ class WorkflowCallbacksManager(WorkflowCallbacks):
|
||||
) -> None:
|
||||
"""Handle when an error occurs."""
|
||||
for callback in self._callbacks:
|
||||
if hasattr(callback, "on_error"):
|
||||
callback.on_error(message, cause, stack, details)
|
||||
if hasattr(callback, "error"):
|
||||
callback.error(message, cause, stack, details)
|
||||
|
||||
def on_warning(self, message: str, details: dict | None = None) -> None:
|
||||
def warning(self, message: str, details: dict | None = None) -> None:
|
||||
"""Handle when a warning occurs."""
|
||||
for callback in self._callbacks:
|
||||
if hasattr(callback, "on_warning"):
|
||||
callback.on_warning(message, details)
|
||||
if hasattr(callback, "warning"):
|
||||
callback.warning(message, details)
|
||||
|
||||
def on_log(self, message: str, details: dict | None = None) -> None:
|
||||
def log(self, message: str, details: dict | None = None) -> None:
|
||||
"""Handle when a log message occurs."""
|
||||
for callback in self._callbacks:
|
||||
if hasattr(callback, "on_log"):
|
||||
callback.on_log(message, details)
|
||||
|
||||
def on_measure(self, name: str, value: float, details: dict | None = None) -> None:
|
||||
"""Handle when a measurement occurs."""
|
||||
for callback in self._callbacks:
|
||||
if hasattr(callback, "on_measure"):
|
||||
callback.on_measure(name, value, details)
|
||||
if hasattr(callback, "log"):
|
||||
callback.log(message, details)
|
||||
|
||||
@ -7,7 +7,7 @@ from typing import cast
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.models.chunking_config import ChunkStrategyType
|
||||
from graphrag.index.operations.chunk_text.chunk_text import chunk_text
|
||||
from graphrag.index.utils.hashing import gen_sha512_hash
|
||||
@ -16,7 +16,7 @@ from graphrag.logger.progress import Progress
|
||||
|
||||
def create_base_text_units(
|
||||
documents: pd.DataFrame,
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
group_by_columns: list[str],
|
||||
size: int,
|
||||
overlap: int,
|
||||
|
||||
@ -8,7 +8,7 @@ from uuid import uuid4
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.enums import AsyncType
|
||||
from graphrag.index.operations.summarize_communities import (
|
||||
prepare_community_reports,
|
||||
@ -43,7 +43,7 @@ async def create_final_community_reports(
|
||||
entities: pd.DataFrame,
|
||||
communities: pd.DataFrame,
|
||||
claims_input: pd.DataFrame | None,
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
cache: PipelineCache,
|
||||
summarization_strategy: dict,
|
||||
async_mode: AsyncType = AsyncType.AsyncIO,
|
||||
|
||||
@ -9,7 +9,7 @@ from uuid import uuid4
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.enums import AsyncType
|
||||
from graphrag.index.operations.extract_covariates.extract_covariates import (
|
||||
extract_covariates,
|
||||
@ -18,7 +18,7 @@ from graphrag.index.operations.extract_covariates.extract_covariates import (
|
||||
|
||||
async def create_final_covariates(
|
||||
text_units: pd.DataFrame,
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
cache: PipelineCache,
|
||||
covariate_type: str,
|
||||
extraction_strategy: dict[str, Any] | None,
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.models.embed_graph_config import EmbedGraphConfig
|
||||
from graphrag.index.operations.compute_degree import compute_degree
|
||||
from graphrag.index.operations.create_graph import create_graph
|
||||
@ -17,7 +17,7 @@ def create_final_nodes(
|
||||
base_entity_nodes: pd.DataFrame,
|
||||
base_relationship_edges: pd.DataFrame,
|
||||
base_communities: pd.DataFrame,
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
embed_config: EmbedGraphConfig,
|
||||
layout_enabled: bool,
|
||||
) -> pd.DataFrame:
|
||||
|
||||
@ -9,7 +9,7 @@ from uuid import uuid4
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.enums import AsyncType
|
||||
from graphrag.index.operations.extract_entities import extract_entities
|
||||
from graphrag.index.operations.summarize_descriptions import (
|
||||
@ -19,7 +19,7 @@ from graphrag.index.operations.summarize_descriptions import (
|
||||
|
||||
async def extract_graph(
|
||||
text_units: pd.DataFrame,
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
cache: PipelineCache,
|
||||
extraction_strategy: dict[str, Any] | None = None,
|
||||
extraction_num_threads: int = 4,
|
||||
|
||||
@ -8,7 +8,7 @@ import logging
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.index.config.embeddings import (
|
||||
community_full_content_embedding,
|
||||
community_summary_embedding,
|
||||
@ -32,7 +32,7 @@ async def generate_text_embeddings(
|
||||
final_text_units: pd.DataFrame | None,
|
||||
final_entities: pd.DataFrame | None,
|
||||
final_community_reports: pd.DataFrame | None,
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
cache: PipelineCache,
|
||||
storage: PipelineStorage,
|
||||
text_embed_config: dict,
|
||||
@ -110,7 +110,7 @@ async def _run_and_snapshot_embeddings(
|
||||
name: str,
|
||||
data: pd.DataFrame,
|
||||
embed_column: str,
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
cache: PipelineCache,
|
||||
storage: PipelineStorage,
|
||||
text_embed_config: dict,
|
||||
|
||||
@ -30,7 +30,7 @@ from .mock_llm import MockChatLLM
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.index.typing import ErrorHandlerFn
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@ -105,7 +105,7 @@ def load_llm(
|
||||
name: str,
|
||||
config: LLMParameters,
|
||||
*,
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
cache: PipelineCache | None,
|
||||
chat_only=False,
|
||||
) -> ChatLLM:
|
||||
@ -135,7 +135,7 @@ def load_llm_embeddings(
|
||||
name: str,
|
||||
llm_config: LLMParameters,
|
||||
*,
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
cache: PipelineCache | None,
|
||||
chat_only=False,
|
||||
) -> EmbeddingsLLM:
|
||||
@ -160,7 +160,7 @@ def load_llm_embeddings(
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def _create_error_handler(callbacks: VerbCallbacks) -> ErrorHandlerFn:
|
||||
def _create_error_handler(callbacks: WorkflowCallbacks) -> ErrorHandlerFn:
|
||||
def on_error(
|
||||
error: BaseException | None = None,
|
||||
stack: str | None = None,
|
||||
|
||||
@ -7,7 +7,7 @@ from typing import Any, cast
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.models.chunking_config import ChunkingConfig, ChunkStrategyType
|
||||
from graphrag.index.operations.chunk_text.typing import (
|
||||
ChunkInput,
|
||||
@ -23,7 +23,7 @@ def chunk_text(
|
||||
overlap: int,
|
||||
encoding_model: str,
|
||||
strategy: ChunkStrategyType,
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
) -> pd.Series:
|
||||
"""
|
||||
Chunk a piece of text into smaller pieces.
|
||||
|
||||
@ -11,7 +11,7 @@ import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingStrategy
|
||||
from graphrag.utils.embeddings import create_collection_name
|
||||
from graphrag.vector_stores.base import BaseVectorStore, VectorStoreDocument
|
||||
@ -37,7 +37,7 @@ class TextEmbedStrategyType(str, Enum):
|
||||
|
||||
async def embed_text(
|
||||
input: pd.DataFrame,
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
cache: PipelineCache,
|
||||
embed_column: str,
|
||||
strategy: dict,
|
||||
@ -109,7 +109,7 @@ async def embed_text(
|
||||
|
||||
async def _text_embed_in_memory(
|
||||
input: pd.DataFrame,
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
cache: PipelineCache,
|
||||
embed_column: str,
|
||||
strategy: dict,
|
||||
@ -126,7 +126,7 @@ async def _text_embed_in_memory(
|
||||
|
||||
async def _text_embed_with_vector_store(
|
||||
input: pd.DataFrame,
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
cache: PipelineCache,
|
||||
embed_column: str,
|
||||
strategy: dict[str, Any],
|
||||
|
||||
@ -8,14 +8,14 @@ from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingResult
|
||||
from graphrag.logger.progress import ProgressTicker, progress_ticker
|
||||
|
||||
|
||||
async def run( # noqa RUF029 async is required for interface
|
||||
input: list[str],
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
cache: PipelineCache,
|
||||
_args: dict[str, Any],
|
||||
) -> TextEmbeddingResult:
|
||||
|
||||
@ -13,7 +13,7 @@ from pydantic import TypeAdapter
|
||||
|
||||
import graphrag.config.defaults as defs
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.models.llm_parameters import LLMParameters
|
||||
from graphrag.index.llm.load_llm import load_llm_embeddings
|
||||
from graphrag.index.operations.embed_text.strategies.typing import TextEmbeddingResult
|
||||
@ -26,7 +26,7 @@ log = logging.getLogger(__name__)
|
||||
|
||||
async def run(
|
||||
input: list[str],
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
cache: PipelineCache,
|
||||
args: dict[str, Any],
|
||||
) -> TextEmbeddingResult:
|
||||
@ -75,7 +75,7 @@ def _get_splitter(config: LLMParameters, batch_max_tokens: int) -> TokenTextSpli
|
||||
|
||||
def _get_llm(
|
||||
config: LLMParameters,
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
cache: PipelineCache,
|
||||
) -> EmbeddingsLLM:
|
||||
return load_llm_embeddings(
|
||||
|
||||
@ -7,7 +7,7 @@ from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -20,7 +20,7 @@ class TextEmbeddingResult:
|
||||
TextEmbeddingStrategy = Callable[
|
||||
[
|
||||
list[str],
|
||||
VerbCallbacks,
|
||||
WorkflowCallbacks,
|
||||
PipelineCache,
|
||||
dict,
|
||||
],
|
||||
|
||||
@ -12,7 +12,7 @@ import pandas as pd
|
||||
|
||||
import graphrag.config.defaults as defs
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.enums import AsyncType
|
||||
from graphrag.index.llm.load_llm import load_llm, read_llm_params
|
||||
from graphrag.index.operations.extract_covariates.claim_extractor import ClaimExtractor
|
||||
@ -30,7 +30,7 @@ DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event"]
|
||||
|
||||
async def extract_covariates(
|
||||
input: pd.DataFrame,
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
cache: PipelineCache,
|
||||
column: str,
|
||||
covariate_type: str,
|
||||
@ -78,7 +78,7 @@ async def run_claim_extraction(
|
||||
input: str | Iterable[str],
|
||||
entity_types: list[str],
|
||||
resolved_entities_map: dict[str, str],
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
cache: PipelineCache,
|
||||
strategy_config: dict[str, Any],
|
||||
) -> CovariateExtractionResult:
|
||||
|
||||
@ -8,7 +8,7 @@ from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -41,7 +41,7 @@ CovariateExtractStrategy = Callable[
|
||||
Iterable[str],
|
||||
list[str],
|
||||
dict[str, str],
|
||||
VerbCallbacks,
|
||||
WorkflowCallbacks,
|
||||
PipelineCache,
|
||||
dict[str, Any],
|
||||
],
|
||||
|
||||
@ -9,7 +9,7 @@ from typing import Any
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.enums import AsyncType
|
||||
from graphrag.index.bootstrap import bootstrap
|
||||
from graphrag.index.operations.extract_entities.typing import (
|
||||
@ -27,7 +27,7 @@ DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event"]
|
||||
|
||||
async def extract_entities(
|
||||
text_units: pd.DataFrame,
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
cache: PipelineCache,
|
||||
text_column: str,
|
||||
id_column: str,
|
||||
|
||||
@ -8,7 +8,7 @@ from fnllm import ChatLLM
|
||||
|
||||
import graphrag.config.defaults as defs
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.index.llm.load_llm import load_llm, read_llm_params
|
||||
from graphrag.index.operations.extract_entities.graph_extractor import GraphExtractor
|
||||
from graphrag.index.operations.extract_entities.typing import (
|
||||
@ -22,7 +22,7 @@ from graphrag.index.operations.extract_entities.typing import (
|
||||
async def run_graph_intelligence(
|
||||
docs: list[Document],
|
||||
entity_types: EntityTypes,
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
cache: PipelineCache,
|
||||
args: StrategyConfig,
|
||||
) -> EntityExtractionResult:
|
||||
@ -36,7 +36,7 @@ async def run_extract_entities(
|
||||
llm: ChatLLM,
|
||||
docs: list[Document],
|
||||
entity_types: EntityTypes,
|
||||
callbacks: VerbCallbacks | None,
|
||||
callbacks: WorkflowCallbacks | None,
|
||||
args: StrategyConfig,
|
||||
) -> EntityExtractionResult:
|
||||
"""Run the entity extraction chain."""
|
||||
|
||||
@ -8,7 +8,7 @@ import nltk
|
||||
from nltk.corpus import words
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.index.operations.extract_entities.typing import (
|
||||
Document,
|
||||
EntityExtractionResult,
|
||||
@ -23,7 +23,7 @@ words.ensure_loaded()
|
||||
async def run( # noqa RUF029 async is required for interface
|
||||
docs: list[Document],
|
||||
entity_types: EntityTypes,
|
||||
callbacks: VerbCallbacks, # noqa ARG001
|
||||
callbacks: WorkflowCallbacks, # noqa ARG001
|
||||
cache: PipelineCache, # noqa ARG001
|
||||
args: StrategyConfig, # noqa ARG001
|
||||
) -> EntityExtractionResult:
|
||||
|
||||
@ -11,7 +11,7 @@ from typing import Any
|
||||
import networkx as nx
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
|
||||
ExtractedEntity = dict[str, Any]
|
||||
ExtractedRelationship = dict[str, Any]
|
||||
@ -40,7 +40,7 @@ EntityExtractStrategy = Callable[
|
||||
[
|
||||
list[Document],
|
||||
EntityTypes,
|
||||
VerbCallbacks,
|
||||
WorkflowCallbacks,
|
||||
PipelineCache,
|
||||
StrategyConfig,
|
||||
],
|
||||
|
||||
@ -6,14 +6,14 @@
|
||||
import networkx as nx
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.index.operations.embed_graph.typing import NodeEmbeddings
|
||||
from graphrag.index.operations.layout_graph.typing import GraphLayout
|
||||
|
||||
|
||||
def layout_graph(
|
||||
graph: nx.Graph,
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
enabled: bool,
|
||||
embeddings: NodeEmbeddings | None,
|
||||
):
|
||||
@ -58,7 +58,7 @@ def _run_layout(
|
||||
graph: nx.Graph,
|
||||
enabled: bool,
|
||||
embeddings: NodeEmbeddings,
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
) -> GraphLayout:
|
||||
if enabled:
|
||||
from graphrag.index.operations.layout_graph.umap import (
|
||||
|
||||
@ -8,7 +8,7 @@ import logging
|
||||
import pandas as pd
|
||||
|
||||
import graphrag.index.operations.summarize_communities.community_reports_extractor.schemas as schemas
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.index.operations.summarize_communities.community_reports_extractor.sort_context import (
|
||||
parallel_sort_context_batch,
|
||||
)
|
||||
@ -24,7 +24,7 @@ def prepare_community_reports(
|
||||
nodes,
|
||||
edges,
|
||||
claims,
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
max_tokens: int = 16_000,
|
||||
):
|
||||
"""Prep communities for report generation."""
|
||||
|
||||
@ -9,7 +9,7 @@ import traceback
|
||||
from fnllm import ChatLLM
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.index.llm.load_llm import load_llm, read_llm_params
|
||||
from graphrag.index.operations.summarize_communities.community_reports_extractor.community_reports_extractor import (
|
||||
CommunityReportsExtractor,
|
||||
@ -28,7 +28,7 @@ async def run_graph_intelligence(
|
||||
community: str | int,
|
||||
input: str,
|
||||
level: int,
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
cache: PipelineCache,
|
||||
args: StrategyConfig,
|
||||
) -> CommunityReport | None:
|
||||
@ -44,7 +44,7 @@ async def _run_extractor(
|
||||
input: str,
|
||||
level: int,
|
||||
args: StrategyConfig,
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
) -> CommunityReport | None:
|
||||
# RateLimiter
|
||||
rate_limiter = RateLimiter(rate=1, per=60)
|
||||
|
||||
@ -10,8 +10,8 @@ import pandas as pd
|
||||
import graphrag.config.defaults as defaults
|
||||
import graphrag.index.operations.summarize_communities.community_reports_extractor.schemas as schemas
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.enums import AsyncType
|
||||
from graphrag.index.operations.summarize_communities.community_reports_extractor import (
|
||||
prep_community_report_context,
|
||||
@ -34,7 +34,7 @@ async def summarize_communities(
|
||||
local_contexts,
|
||||
nodes,
|
||||
community_hierarchy,
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
cache: PipelineCache,
|
||||
strategy: dict,
|
||||
async_mode: AsyncType = AsyncType.AsyncIO,
|
||||
@ -73,7 +73,7 @@ async def summarize_communities(
|
||||
local_reports = await derive_from_rows(
|
||||
level_contexts,
|
||||
run_generate,
|
||||
callbacks=NoopVerbCallbacks(),
|
||||
callbacks=NoopWorkflowCallbacks(),
|
||||
num_threads=num_threads,
|
||||
async_type=async_mode,
|
||||
)
|
||||
@ -84,7 +84,7 @@ async def summarize_communities(
|
||||
|
||||
async def _generate_report(
|
||||
runner: CommunityReportsStrategy,
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
cache: PipelineCache,
|
||||
strategy: dict,
|
||||
community_id: int,
|
||||
|
||||
@ -10,7 +10,7 @@ from typing import Any
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
|
||||
ExtractedEntity = dict[str, Any]
|
||||
StrategyConfig = dict[str, Any]
|
||||
@ -45,7 +45,7 @@ CommunityReportsStrategy = Callable[
|
||||
str | int,
|
||||
str,
|
||||
int,
|
||||
VerbCallbacks,
|
||||
WorkflowCallbacks,
|
||||
PipelineCache,
|
||||
StrategyConfig,
|
||||
],
|
||||
|
||||
@ -6,7 +6,7 @@
|
||||
from fnllm import ChatLLM
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.index.llm.load_llm import load_llm, read_llm_params
|
||||
from graphrag.index.operations.summarize_descriptions.description_summary_extractor import (
|
||||
SummarizeExtractor,
|
||||
@ -20,7 +20,7 @@ from graphrag.index.operations.summarize_descriptions.typing import (
|
||||
async def run_graph_intelligence(
|
||||
id: str | tuple[str, str],
|
||||
descriptions: list[str],
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
cache: PipelineCache,
|
||||
args: StrategyConfig,
|
||||
) -> SummarizedDescriptionResult:
|
||||
@ -36,7 +36,7 @@ async def run_summarize_descriptions(
|
||||
llm: ChatLLM,
|
||||
id: str | tuple[str, str],
|
||||
descriptions: list[str],
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
args: StrategyConfig,
|
||||
) -> SummarizedDescriptionResult:
|
||||
"""Run the entity extraction chain."""
|
||||
|
||||
@ -10,7 +10,7 @@ from typing import Any
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.index.operations.summarize_descriptions.typing import (
|
||||
SummarizationStrategy,
|
||||
SummarizeStrategyType,
|
||||
@ -23,7 +23,7 @@ log = logging.getLogger(__name__)
|
||||
async def summarize_descriptions(
|
||||
entities_df: pd.DataFrame,
|
||||
relationships_df: pd.DataFrame,
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
cache: PipelineCache,
|
||||
strategy: dict[str, Any] | None = None,
|
||||
num_threads: int = 4,
|
||||
|
||||
@ -9,7 +9,7 @@ from enum import Enum
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
|
||||
StrategyConfig = dict[str, Any]
|
||||
|
||||
@ -26,7 +26,7 @@ SummarizationStrategy = Callable[
|
||||
[
|
||||
str | tuple[str, str],
|
||||
list[str],
|
||||
VerbCallbacks,
|
||||
WorkflowCallbacks,
|
||||
PipelineCache,
|
||||
StrategyConfig,
|
||||
],
|
||||
|
||||
@ -12,7 +12,7 @@ from typing import Any, TypeVar, cast
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.enums import AsyncType
|
||||
from graphrag.logger.progress import progress_ticker
|
||||
|
||||
@ -32,7 +32,7 @@ class ParallelizationError(ValueError):
|
||||
async def derive_from_rows(
|
||||
input: pd.DataFrame,
|
||||
transform: Callable[[pd.Series], Awaitable[ItemType]],
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
num_threads: int = 4,
|
||||
async_type: AsyncType = AsyncType.AsyncIO,
|
||||
) -> list[ItemType | None]:
|
||||
@ -57,7 +57,7 @@ async def derive_from_rows(
|
||||
async def derive_from_rows_asyncio_threads(
|
||||
input: pd.DataFrame,
|
||||
transform: Callable[[pd.Series], Awaitable[ItemType]],
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
num_threads: int | None = 4,
|
||||
) -> list[ItemType | None]:
|
||||
"""
|
||||
@ -87,7 +87,7 @@ async def derive_from_rows_asyncio_threads(
|
||||
async def derive_from_rows_asyncio(
|
||||
input: pd.DataFrame,
|
||||
transform: Callable[[pd.Series], Awaitable[ItemType]],
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
num_threads: int = 4,
|
||||
) -> list[ItemType | None]:
|
||||
"""
|
||||
@ -121,7 +121,7 @@ GatherFn = Callable[[ExecuteFn], Awaitable[list[ItemType | None]]]
|
||||
async def _derive_from_rows_base(
|
||||
input: pd.DataFrame,
|
||||
transform: Callable[[pd.Series], Awaitable[ItemType]],
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
gather: GatherFn[ItemType],
|
||||
) -> list[ItemType | None]:
|
||||
"""
|
||||
|
||||
@ -9,15 +9,13 @@ import time
|
||||
import traceback
|
||||
from collections.abc import AsyncIterable
|
||||
from dataclasses import asdict
|
||||
from typing import cast
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.cache.factory import CacheFactory
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.console_workflow_callbacks import ConsoleWorkflowCallbacks
|
||||
from graphrag.callbacks.delegating_verb_callbacks import DelegatingVerbCallbacks
|
||||
from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks
|
||||
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.index.context import PipelineRunStats
|
||||
@ -119,7 +117,7 @@ async def run_workflows(
|
||||
update_storage=update_index_storage,
|
||||
config=config,
|
||||
cache=cache,
|
||||
callbacks=NoopVerbCallbacks(),
|
||||
callbacks=NoopWorkflowCallbacks(),
|
||||
progress_logger=progress_logger,
|
||||
)
|
||||
|
||||
@ -163,16 +161,15 @@ async def _run_workflows(
|
||||
last_workflow = workflow
|
||||
run_workflow = all_workflows[workflow]
|
||||
progress = logger.child(workflow, transient=False)
|
||||
callbacks.on_workflow_start(workflow, None)
|
||||
verb_callbacks = DelegatingVerbCallbacks(workflow, callbacks)
|
||||
callbacks.workflow_start(workflow, None)
|
||||
work_time = time.time()
|
||||
result = await run_workflow(
|
||||
config,
|
||||
context,
|
||||
verb_callbacks,
|
||||
callbacks,
|
||||
)
|
||||
progress(Progress(percent=1))
|
||||
callbacks.on_workflow_end(workflow, result)
|
||||
callbacks.workflow_end(workflow, result)
|
||||
yield PipelineRunResult(workflow, result, None)
|
||||
|
||||
context.stats.workflows[workflow] = {"overall": time.time() - work_time}
|
||||
@ -186,9 +183,7 @@ async def _run_workflows(
|
||||
|
||||
except Exception as e:
|
||||
log.exception("error running workflow %s", last_workflow)
|
||||
cast("WorkflowCallbacks", callbacks).on_error(
|
||||
"Error running pipeline!", e, traceback.format_exc()
|
||||
)
|
||||
callbacks.error("Error running pipeline!", e, traceback.format_exc())
|
||||
yield PipelineRunResult(last_workflow, None, [e])
|
||||
|
||||
|
||||
|
||||
@ -10,7 +10,7 @@ import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.index.operations.summarize_descriptions.graph_intelligence_strategy import (
|
||||
run_graph_intelligence as run_entity_summarization,
|
||||
@ -92,7 +92,7 @@ async def _run_entity_summarization(
|
||||
entities_df: pd.DataFrame,
|
||||
config: GraphRagConfig,
|
||||
cache: PipelineCache,
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
) -> pd.DataFrame:
|
||||
"""Run entity summarization.
|
||||
|
||||
|
||||
@ -9,7 +9,7 @@ import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.cache.pipeline_cache import PipelineCache
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.index.config.embeddings import get_embedded_fields, get_embedding_settings
|
||||
from graphrag.index.flows.generate_text_embeddings import generate_text_embeddings
|
||||
@ -86,7 +86,7 @@ async def update_dataframe_outputs(
|
||||
update_storage: PipelineStorage,
|
||||
config: GraphRagConfig,
|
||||
cache: PipelineCache,
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
progress_logger: ProgressLogger,
|
||||
) -> None:
|
||||
"""Update the mergeable outputs.
|
||||
|
||||
@ -6,7 +6,7 @@
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks
|
||||
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.index.llm.load_llm import load_llm, load_llm_embeddings
|
||||
from graphrag.logger.print_progress import ProgressLogger
|
||||
@ -18,7 +18,7 @@ def validate_config_names(logger: ProgressLogger, parameters: GraphRagConfig) ->
|
||||
llm = load_llm(
|
||||
"test-llm",
|
||||
parameters.llm,
|
||||
callbacks=NoopVerbCallbacks(),
|
||||
callbacks=NoopWorkflowCallbacks(),
|
||||
cache=None,
|
||||
)
|
||||
try:
|
||||
@ -32,7 +32,7 @@ def validate_config_names(logger: ProgressLogger, parameters: GraphRagConfig) ->
|
||||
embed_llm = load_llm_embeddings(
|
||||
"test-embed-llm",
|
||||
parameters.embeddings.llm,
|
||||
callbacks=NoopVerbCallbacks(),
|
||||
callbacks=NoopWorkflowCallbacks(),
|
||||
cache=None,
|
||||
)
|
||||
try:
|
||||
|
||||
@ -8,7 +8,7 @@ from collections.abc import Awaitable, Callable
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.index.context import PipelineRunContext
|
||||
|
||||
@ -88,7 +88,7 @@ from .generate_text_embeddings import (
|
||||
all_workflows: dict[
|
||||
str,
|
||||
Callable[
|
||||
[GraphRagConfig, PipelineRunContext, VerbCallbacks],
|
||||
[GraphRagConfig, PipelineRunContext, WorkflowCallbacks],
|
||||
Awaitable[pd.DataFrame | None],
|
||||
],
|
||||
] = {
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.index.context import PipelineRunContext
|
||||
from graphrag.index.flows.compute_communities import compute_communities
|
||||
@ -17,7 +17,7 @@ workflow_name = "compute_communities"
|
||||
async def run_workflow(
|
||||
config: GraphRagConfig,
|
||||
context: PipelineRunContext,
|
||||
_callbacks: VerbCallbacks,
|
||||
_callbacks: WorkflowCallbacks,
|
||||
) -> pd.DataFrame | None:
|
||||
"""All the steps to create the base communities."""
|
||||
base_relationship_edges = await load_table_from_storage(
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.index.context import PipelineRunContext
|
||||
from graphrag.index.flows.create_base_text_units import (
|
||||
@ -19,7 +19,7 @@ workflow_name = "create_base_text_units"
|
||||
async def run_workflow(
|
||||
config: GraphRagConfig,
|
||||
context: PipelineRunContext,
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
) -> pd.DataFrame | None:
|
||||
"""All the steps to transform base text_units."""
|
||||
documents = await load_table_from_storage("input", context.storage)
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.index.context import PipelineRunContext
|
||||
from graphrag.index.flows.create_final_communities import (
|
||||
@ -19,7 +19,7 @@ workflow_name = "create_final_communities"
|
||||
async def run_workflow(
|
||||
_config: GraphRagConfig,
|
||||
context: PipelineRunContext,
|
||||
_callbacks: VerbCallbacks,
|
||||
_callbacks: WorkflowCallbacks,
|
||||
) -> pd.DataFrame | None:
|
||||
"""All the steps to transform final communities."""
|
||||
base_entity_nodes = await load_table_from_storage(
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.index.context import PipelineRunContext
|
||||
from graphrag.index.flows.create_final_community_reports import (
|
||||
@ -19,7 +19,7 @@ workflow_name = "create_final_community_reports"
|
||||
async def run_workflow(
|
||||
config: GraphRagConfig,
|
||||
context: PipelineRunContext,
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
) -> pd.DataFrame | None:
|
||||
"""All the steps to transform community reports."""
|
||||
nodes = await load_table_from_storage("create_final_nodes", context.storage)
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.index.context import PipelineRunContext
|
||||
from graphrag.index.flows.create_final_covariates import (
|
||||
@ -19,7 +19,7 @@ workflow_name = "create_final_covariates"
|
||||
async def run_workflow(
|
||||
config: GraphRagConfig,
|
||||
context: PipelineRunContext,
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
) -> pd.DataFrame | None:
|
||||
"""All the steps to extract and format covariates."""
|
||||
text_units = await load_table_from_storage(
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.index.context import PipelineRunContext
|
||||
from graphrag.index.flows.create_final_documents import (
|
||||
@ -19,7 +19,7 @@ workflow_name = "create_final_documents"
|
||||
async def run_workflow(
|
||||
config: GraphRagConfig,
|
||||
context: PipelineRunContext,
|
||||
_callbacks: VerbCallbacks,
|
||||
_callbacks: WorkflowCallbacks,
|
||||
) -> pd.DataFrame | None:
|
||||
"""All the steps to transform final documents."""
|
||||
documents = await load_table_from_storage("input", context.storage)
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.index.context import PipelineRunContext
|
||||
from graphrag.index.flows.create_final_entities import (
|
||||
@ -19,7 +19,7 @@ workflow_name = "create_final_entities"
|
||||
async def run_workflow(
|
||||
_config: GraphRagConfig,
|
||||
context: PipelineRunContext,
|
||||
_callbacks: VerbCallbacks,
|
||||
_callbacks: WorkflowCallbacks,
|
||||
) -> pd.DataFrame | None:
|
||||
"""All the steps to transform final entities."""
|
||||
base_entity_nodes = await load_table_from_storage(
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.index.context import PipelineRunContext
|
||||
from graphrag.index.flows.create_final_nodes import (
|
||||
@ -19,7 +19,7 @@ workflow_name = "create_final_nodes"
|
||||
async def run_workflow(
|
||||
config: GraphRagConfig,
|
||||
context: PipelineRunContext,
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
) -> pd.DataFrame | None:
|
||||
"""All the steps to transform final nodes."""
|
||||
base_entity_nodes = await load_table_from_storage(
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.index.context import PipelineRunContext
|
||||
from graphrag.index.flows.create_final_relationships import (
|
||||
@ -19,7 +19,7 @@ workflow_name = "create_final_relationships"
|
||||
async def run_workflow(
|
||||
_config: GraphRagConfig,
|
||||
context: PipelineRunContext,
|
||||
_callbacks: VerbCallbacks,
|
||||
_callbacks: WorkflowCallbacks,
|
||||
) -> pd.DataFrame | None:
|
||||
"""All the steps to transform final relationships."""
|
||||
base_relationship_edges = await load_table_from_storage(
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.index.context import PipelineRunContext
|
||||
from graphrag.index.flows.create_final_text_units import (
|
||||
@ -19,7 +19,7 @@ workflow_name = "create_final_text_units"
|
||||
async def run_workflow(
|
||||
config: GraphRagConfig,
|
||||
context: PipelineRunContext,
|
||||
_callbacks: VerbCallbacks,
|
||||
_callbacks: WorkflowCallbacks,
|
||||
) -> pd.DataFrame | None:
|
||||
"""All the steps to transform the text units."""
|
||||
text_units = await load_table_from_storage(
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.index.context import PipelineRunContext
|
||||
from graphrag.index.flows.extract_graph import (
|
||||
@ -21,7 +21,7 @@ workflow_name = "extract_graph"
|
||||
async def run_workflow(
|
||||
config: GraphRagConfig,
|
||||
context: PipelineRunContext,
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
) -> pd.DataFrame | None:
|
||||
"""All the steps to create the base entity graph."""
|
||||
text_units = await load_table_from_storage(
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from graphrag.callbacks.verb_callbacks import VerbCallbacks
|
||||
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.index.config.embeddings import get_embedded_fields, get_embedding_settings
|
||||
from graphrag.index.context import PipelineRunContext
|
||||
@ -20,7 +20,7 @@ workflow_name = "generate_text_embeddings"
|
||||
async def run_workflow(
|
||||
config: GraphRagConfig,
|
||||
context: PipelineRunContext,
|
||||
callbacks: VerbCallbacks,
|
||||
callbacks: WorkflowCallbacks,
|
||||
) -> pd.DataFrame | None:
|
||||
"""All the steps to transform community reports."""
|
||||
final_documents = await load_table_from_storage(
|
||||
|
||||
@ -9,7 +9,7 @@ from fnllm import ChatLLM
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
import graphrag.config.defaults as defs
|
||||
from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks
|
||||
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
|
||||
from graphrag.config.models.graph_rag_config import GraphRagConfig
|
||||
from graphrag.config.models.llm_parameters import LLMParameters
|
||||
from graphrag.index.input.factory import create_input
|
||||
@ -77,7 +77,7 @@ async def load_docs_in_chunks(
|
||||
overlap=MIN_CHUNK_OVERLAP,
|
||||
encoding_model=defs.ENCODING_MODEL,
|
||||
strategy=chunk_config.strategy,
|
||||
callbacks=NoopVerbCallbacks(),
|
||||
callbacks=NoopWorkflowCallbacks(),
|
||||
)
|
||||
|
||||
# Select chunks into a new df and explode it
|
||||
@ -98,7 +98,7 @@ async def load_docs_in_chunks(
|
||||
embedding_llm = load_llm_embeddings(
|
||||
"prompt_tuning_embeddings",
|
||||
llm_config,
|
||||
callbacks=NoopVerbCallbacks(),
|
||||
callbacks=NoopWorkflowCallbacks(),
|
||||
cache=None,
|
||||
)
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks
|
||||
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
|
||||
from graphrag.config.create_graphrag_config import create_graphrag_config
|
||||
from graphrag.index.workflows.compute_communities import run_workflow
|
||||
from graphrag.utils.storage import load_table_from_storage
|
||||
@ -25,7 +25,7 @@ async def test_compute_communities():
|
||||
await run_workflow(
|
||||
config,
|
||||
context,
|
||||
NoopVerbCallbacks(),
|
||||
NoopWorkflowCallbacks(),
|
||||
)
|
||||
|
||||
actual = await load_table_from_storage("base_communities", context.storage)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks
|
||||
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
|
||||
from graphrag.config.create_graphrag_config import create_graphrag_config
|
||||
from graphrag.index.workflows.create_base_text_units import run_workflow, workflow_name
|
||||
from graphrag.utils.storage import load_table_from_storage
|
||||
@ -25,7 +25,7 @@ async def test_create_base_text_units():
|
||||
await run_workflow(
|
||||
config,
|
||||
context,
|
||||
NoopVerbCallbacks(),
|
||||
NoopWorkflowCallbacks(),
|
||||
)
|
||||
|
||||
actual = await load_table_from_storage(workflow_name, context.storage)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks
|
||||
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
|
||||
from graphrag.config.create_graphrag_config import create_graphrag_config
|
||||
from graphrag.index.workflows.create_final_communities import (
|
||||
run_workflow,
|
||||
@ -32,7 +32,7 @@ async def test_create_final_communities():
|
||||
await run_workflow(
|
||||
config,
|
||||
context,
|
||||
NoopVerbCallbacks(),
|
||||
NoopWorkflowCallbacks(),
|
||||
)
|
||||
|
||||
actual = await load_table_from_storage(workflow_name, context.storage)
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks
|
||||
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
|
||||
from graphrag.config.create_graphrag_config import create_graphrag_config
|
||||
from graphrag.config.enums import LLMType
|
||||
from graphrag.index.operations.summarize_communities.community_reports_extractor.community_reports_extractor import (
|
||||
@ -70,7 +70,7 @@ async def test_create_final_community_reports():
|
||||
await run_workflow(
|
||||
config,
|
||||
context,
|
||||
NoopVerbCallbacks(),
|
||||
NoopWorkflowCallbacks(),
|
||||
)
|
||||
|
||||
actual = await load_table_from_storage(workflow_name, context.storage)
|
||||
@ -105,5 +105,5 @@ async def test_create_final_community_reports_missing_llm_throws():
|
||||
await run_workflow(
|
||||
config,
|
||||
context,
|
||||
NoopVerbCallbacks(),
|
||||
NoopWorkflowCallbacks(),
|
||||
)
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
import pytest
|
||||
from pandas.testing import assert_series_equal
|
||||
|
||||
from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks
|
||||
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
|
||||
from graphrag.config.create_graphrag_config import create_graphrag_config
|
||||
from graphrag.config.enums import LLMType
|
||||
from graphrag.index.run.derive_from_rows import ParallelizationError
|
||||
@ -46,7 +46,7 @@ async def test_create_final_covariates():
|
||||
await run_workflow(
|
||||
config,
|
||||
context,
|
||||
NoopVerbCallbacks(),
|
||||
NoopWorkflowCallbacks(),
|
||||
)
|
||||
|
||||
actual = await load_table_from_storage(workflow_name, context.storage)
|
||||
@ -95,5 +95,5 @@ async def test_create_final_covariates_missing_llm_throws():
|
||||
await run_workflow(
|
||||
config,
|
||||
context,
|
||||
NoopVerbCallbacks(),
|
||||
NoopWorkflowCallbacks(),
|
||||
)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks
|
||||
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
|
||||
from graphrag.config.create_graphrag_config import create_graphrag_config
|
||||
from graphrag.index.workflows.create_final_documents import (
|
||||
run_workflow,
|
||||
@ -28,7 +28,7 @@ async def test_create_final_documents():
|
||||
await run_workflow(
|
||||
config,
|
||||
context,
|
||||
NoopVerbCallbacks(),
|
||||
NoopWorkflowCallbacks(),
|
||||
)
|
||||
|
||||
actual = await load_table_from_storage(workflow_name, context.storage)
|
||||
@ -49,7 +49,7 @@ async def test_create_final_documents_with_attribute_columns():
|
||||
await run_workflow(
|
||||
config,
|
||||
context,
|
||||
NoopVerbCallbacks(),
|
||||
NoopWorkflowCallbacks(),
|
||||
)
|
||||
|
||||
actual = await load_table_from_storage(workflow_name, context.storage)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks
|
||||
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
|
||||
from graphrag.config.create_graphrag_config import create_graphrag_config
|
||||
from graphrag.index.workflows.create_final_entities import (
|
||||
run_workflow,
|
||||
@ -28,7 +28,7 @@ async def test_create_final_entities():
|
||||
await run_workflow(
|
||||
config,
|
||||
context,
|
||||
NoopVerbCallbacks(),
|
||||
NoopWorkflowCallbacks(),
|
||||
)
|
||||
|
||||
actual = await load_table_from_storage(workflow_name, context.storage)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks
|
||||
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
|
||||
from graphrag.config.create_graphrag_config import create_graphrag_config
|
||||
from graphrag.index.workflows.create_final_nodes import (
|
||||
run_workflow,
|
||||
@ -32,7 +32,7 @@ async def test_create_final_nodes():
|
||||
await run_workflow(
|
||||
config,
|
||||
context,
|
||||
NoopVerbCallbacks(),
|
||||
NoopWorkflowCallbacks(),
|
||||
)
|
||||
|
||||
actual = await load_table_from_storage(workflow_name, context.storage)
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License
|
||||
|
||||
|
||||
from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks
|
||||
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
|
||||
from graphrag.config.create_graphrag_config import create_graphrag_config
|
||||
from graphrag.index.workflows.create_final_relationships import (
|
||||
run_workflow,
|
||||
@ -29,7 +29,7 @@ async def test_create_final_relationships():
|
||||
await run_workflow(
|
||||
config,
|
||||
context,
|
||||
NoopVerbCallbacks(),
|
||||
NoopWorkflowCallbacks(),
|
||||
)
|
||||
|
||||
actual = await load_table_from_storage(workflow_name, context.storage)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks
|
||||
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
|
||||
from graphrag.config.create_graphrag_config import create_graphrag_config
|
||||
from graphrag.index.workflows.create_final_text_units import (
|
||||
run_workflow,
|
||||
@ -34,7 +34,7 @@ async def test_create_final_text_units():
|
||||
await run_workflow(
|
||||
config,
|
||||
context,
|
||||
NoopVerbCallbacks(),
|
||||
NoopWorkflowCallbacks(),
|
||||
)
|
||||
|
||||
actual = await load_table_from_storage(workflow_name, context.storage)
|
||||
@ -60,7 +60,7 @@ async def test_create_final_text_units_no_covariates():
|
||||
await run_workflow(
|
||||
config,
|
||||
context,
|
||||
NoopVerbCallbacks(),
|
||||
NoopWorkflowCallbacks(),
|
||||
)
|
||||
|
||||
actual = await load_table_from_storage(workflow_name, context.storage)
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks
|
||||
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
|
||||
from graphrag.config.create_graphrag_config import create_graphrag_config
|
||||
from graphrag.config.enums import LLMType
|
||||
from graphrag.index.workflows.extract_graph import (
|
||||
@ -68,7 +68,7 @@ async def test_extract_graph():
|
||||
await run_workflow(
|
||||
config,
|
||||
context,
|
||||
NoopVerbCallbacks(),
|
||||
NoopWorkflowCallbacks(),
|
||||
)
|
||||
|
||||
# graph construction creates transient tables for nodes, edges, and communities
|
||||
@ -110,5 +110,5 @@ async def test_extract_graph_missing_llm_throws():
|
||||
await run_workflow(
|
||||
config,
|
||||
context,
|
||||
NoopVerbCallbacks(),
|
||||
NoopWorkflowCallbacks(),
|
||||
)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
from graphrag.callbacks.noop_verb_callbacks import NoopVerbCallbacks
|
||||
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
|
||||
from graphrag.config.create_graphrag_config import create_graphrag_config
|
||||
from graphrag.config.enums import TextEmbeddingTarget
|
||||
from graphrag.index.config.embeddings import (
|
||||
@ -38,7 +38,7 @@ async def test_generate_text_embeddings():
|
||||
await run_workflow(
|
||||
config,
|
||||
context,
|
||||
NoopVerbCallbacks(),
|
||||
NoopWorkflowCallbacks(),
|
||||
)
|
||||
|
||||
parquet_files = context.storage.keys()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user