Reorganize python package structure (#1214)

This commit is contained in:
Josh Bradley 2024-10-10 17:01:42 -04:00 committed by GitHub
parent ce8749bd19
commit d9a005c9b8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
48 changed files with 370 additions and 305 deletions

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Reorganized api,reporter,callback code into separate components. Defined debug profiles."
}

39
.vscode/launch.json vendored
View File

@ -1,12 +1,39 @@
{
"_comment": "Use this file to configure the graphrag project for debugging. You may create other configuration profiles based on these or select one below to use.",
"version": "0.2.0",
"configurations": [
{
"name": "Attach to Node Functions",
"type": "node",
"request": "attach",
"port": 9229,
"preLaunchTask": "func: host start"
"name": "Indexer",
"type": "debugpy",
"request": "launch",
"module": "poetry",
"args": [
"poe", "index",
"--root", "<path_to_ragtest_root_demo>"
],
},
{
"name": "Query",
"type": "debugpy",
"request": "launch",
"module": "poetry",
"args": [
"poe", "query",
"--root", "<path_to_ragtest_root_demo>",
"--method", "global",
"What are the top themes in this story",
]
},
{
"name": "Prompt Tuning",
"type": "debugpy",
"request": "launch",
"module": "poetry",
"args": [
"poe", "prompt_tune",
"--config",
"<path_to_ragtest_root_demo>/settings.yaml",
]
}
]
}
}

View File

@ -38,7 +38,6 @@
],
"python.defaultInterpreterPath": "python/services/.venv/bin/python",
"python.languageServer": "Pylance",
"python.analysis.typeCheckingMode": "basic",
"cSpell.customDictionaries": {
"project-words": {
"name": "project-words",

30
graphrag/api/__init__.py Normal file
View File

@ -0,0 +1,30 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""API for GraphRAG.
WARNING: This API is under development and may undergo changes in future releases.
Backwards compatibility is not guaranteed at this time.
"""
from .index_api import build_index
from .prompt_tune_api import DocSelectionType, generate_indexing_prompts
from .query_api import (
global_search,
global_search_streaming,
local_search,
local_search_streaming,
)
__all__ = [ # noqa: RUF022
# index API
"build_index",
# query API
"global_search",
"global_search_streaming",
"local_search",
"local_search_streaming",
# prompt tuning API
"DocSelectionType",
"generate_indexing_prompts",
]

View File

@ -9,15 +9,12 @@ Backwards compatibility is not guaranteed at this time.
"""
from graphrag.config import CacheType, GraphRagConfig
from .cache.noop_pipeline_cache import NoopPipelineCache
from .create_pipeline_config import create_pipeline_config
from .emit.types import TableEmitterType
from .progress import (
ProgressReporter,
)
from .run import run_pipeline_with_config
from .typing import PipelineRunResult
from graphrag.index.cache.noop_pipeline_cache import NoopPipelineCache
from graphrag.index.create_pipeline_config import create_pipeline_config
from graphrag.index.emit.types import TableEmitterType
from graphrag.index.run import run_pipeline_with_config
from graphrag.index.typing import PipelineRunResult
from graphrag.logging import ProgressReporter
async def build_index(

View File

@ -16,9 +16,8 @@ from pydantic import PositiveInt, validate_call
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.llm import load_llm
from graphrag.index.progress import PrintProgressReporter
from .generator import (
from graphrag.logging import PrintProgressReporter
from graphrag.prompt_tune.generator import (
MAX_TOKEN_COUNT,
create_community_summarization_prompt,
create_entity_extraction_prompt,
@ -31,11 +30,11 @@ from .generator import (
generate_entity_types,
generate_persona,
)
from .loader import (
from graphrag.prompt_tune.loader import (
MIN_CHUNK_SIZE,
load_docs_in_chunks,
)
from .types import DocSelectionType
from graphrag.prompt_tune.types import DocSelectionType
@validate_call

View File

@ -25,21 +25,20 @@ import pandas as pd
from pydantic import validate_call
from graphrag.config import GraphRagConfig
from graphrag.index.progress.types import PrintProgressReporter
from graphrag.logging import PrintProgressReporter
from graphrag.model.entity import Entity
from graphrag.query.structured_search.base import SearchResult # noqa: TCH001
from graphrag.vector_stores.lancedb import LanceDBVectorStore
from graphrag.vector_stores.typing import VectorStoreFactory, VectorStoreType
from .factories import get_global_search_engine, get_local_search_engine
from .indexer_adapters import (
from graphrag.query.factories import get_global_search_engine, get_local_search_engine
from graphrag.query.indexer_adapters import (
read_indexer_covariates,
read_indexer_entities,
read_indexer_relationships,
read_indexer_reports,
read_indexer_text_units,
)
from .input.loaders.dfs import store_entity_semantic_embeddings
from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings
from graphrag.query.structured_search.base import SearchResult # noqa: TCH001
from graphrag.vector_stores.lancedb import LanceDBVectorStore
from graphrag.vector_stores.typing import VectorStoreFactory, VectorStoreType
reporter = PrintProgressReporter("")

View File

@ -0,0 +1,4 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A module containing callback implementations."""

View File

@ -1,7 +1,7 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A reporter that writes to a blob storage."""
"""A logger that emits updates from the indexing engine to a blob in Azure Storage."""
import json
from datetime import datetime, timezone
@ -14,7 +14,7 @@ from datashaper import NoopWorkflowCallbacks
class BlobWorkflowCallbacks(NoopWorkflowCallbacks):
"""A reporter that writes to a blob storage."""
"""A logger that writes to a blob storage account."""
_blob_service_client: BlobServiceClient
_container_name: str

View File

@ -1,13 +1,13 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Console-based reporter for the workflow engine."""
"""A logger that emits updates from the indexing engine to the console."""
from datashaper import NoopWorkflowCallbacks
class ConsoleWorkflowCallbacks(NoopWorkflowCallbacks):
"""A reporter that writes to a console."""
"""A logger that writes to a console."""
def on_error(
self,

View File

@ -1,7 +1,7 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Load pipeline reporter method."""
"""Create a pipeline reporter."""
from pathlib import Path
from typing import cast
@ -20,7 +20,7 @@ from .console_workflow_callbacks import ConsoleWorkflowCallbacks
from .file_workflow_callbacks import FileWorkflowCallbacks
def load_pipeline_reporter(
def create_pipeline_reporter(
config: PipelineReportingConfig | None, root_dir: str | None
) -> WorkflowCallbacks:
"""Create a reporter for the given pipeline config."""

View File

@ -1,7 +1,7 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A reporter that writes to a file."""
"""A logger that emits updates from the indexing engine to a local file."""
import json
import logging
@ -14,12 +14,12 @@ log = logging.getLogger(__name__)
class FileWorkflowCallbacks(NoopWorkflowCallbacks):
"""A reporter that writes to a file."""
"""A logger that writes to a local file."""
_out_stream: TextIOWrapper
def __init__(self, directory: str):
"""Create a new file-based workflow reporter."""
"""Create a new file-based workflow logger."""
Path(directory).mkdir(parents=True, exist_ok=True)
self._out_stream = open( # noqa: PTH123, SIM115
Path(directory) / "logs.json", "a", encoding="utf-8", errors="strict"

View File

@ -3,9 +3,10 @@
"""GlobalSearch LLM Callbacks."""
from graphrag.query.llm.base import BaseLLMCallback
from graphrag.query.structured_search.base import SearchResult
from .llm_callbacks import BaseLLMCallback
class GlobalSearchLLMCallback(BaseLLMCallback):
"""GlobalSearch LLM Callbacks."""

View File

@ -0,0 +1,15 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""LLM Callbacks."""
class BaseLLMCallback:
    """Base class for LLM callbacks.

    Accumulates every token it is notified about in ``response``, in the
    order the tokens arrive.
    """

    def __init__(self):
        """Start with an empty list of received tokens."""
        self.response: list[str] = []

    def on_llm_new_token(self, token: str):
        """Handle a newly generated token by appending it to ``response``."""
        self.response.append(token)

View File

@ -1,13 +1,13 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A workflow callback manager that emits updates to a ProgressReporter."""
"""A workflow callback manager that emits updates."""
from typing import Any
from datashaper import ExecutionNode, NoopWorkflowCallbacks, Progress, TableContainer
from graphrag.index.progress import ProgressReporter
from graphrag.logging import ProgressReporter
class ProgressWorkflowCallbacks(NoopWorkflowCallbacks):

View File

@ -5,11 +5,11 @@
import argparse
from graphrag.logging import ReporterType
from graphrag.utils.cli import dir_exist, file_exist
from .cli import index_cli
from .emit.types import TableEmitterType
from .progress.types import ReporterType
if __name__ == "__main__":
parser = argparse.ArgumentParser(

View File

@ -11,22 +11,21 @@ import time
import warnings
from pathlib import Path
import graphrag.api as api
from graphrag.config import (
CacheType,
enable_logging_with_config,
load_config,
resolve_paths,
)
from graphrag.logging import ProgressReporter, ReporterType, create_progress_reporter
from .api import build_index
from .emit.types import TableEmitterType
from .graph.extractors.claims.prompts import CLAIM_EXTRACTION_PROMPT
from .graph.extractors.community_reports.prompts import COMMUNITY_REPORT_PROMPT
from .graph.extractors.graph.prompts import GRAPH_EXTRACTION_PROMPT
from .graph.extractors.summarize.prompts import SUMMARIZE_PROMPT
from .init_content import INIT_DOTENV, INIT_YAML
from .progress import ProgressReporter, ReporterType
from .progress.load_progress_reporter import load_progress_reporter
from .validate_config import validate_config_names
# Ignore warnings from numba
@ -118,7 +117,7 @@ def index_cli(
output_dir: str | None,
):
"""Run the pipeline with the given config."""
progress_reporter = load_progress_reporter(reporter)
progress_reporter = create_progress_reporter(reporter)
info, error, success = _logger(progress_reporter)
run_id = resume or update_index_id or time.strftime("%Y%m%d-%H%M%S")
@ -161,7 +160,7 @@ def index_cli(
_register_signal_handlers(progress_reporter)
outputs = asyncio.run(
build_index(
api.build_index(
config=config,
run_id=run_id,
is_resume_run=bool(resume),

View File

@ -11,9 +11,9 @@ from typing import cast
import pandas as pd
from graphrag.index.config import PipelineCSVInputConfig, PipelineInputConfig
from graphrag.index.progress import ProgressReporter
from graphrag.index.storage import PipelineStorage
from graphrag.index.utils import gen_md5_hash
from graphrag.logging import ProgressReporter
log = logging.getLogger(__name__)

View File

@ -12,11 +12,11 @@ import pandas as pd
from graphrag.config import InputConfig, InputType
from graphrag.index.config import PipelineInputConfig
from graphrag.index.progress import NullProgressReporter, ProgressReporter
from graphrag.index.storage import (
BlobPipelineStorage,
FilePipelineStorage,
)
from graphrag.logging import NullProgressReporter, ProgressReporter
from .csv import input_type as csv
from .csv import load as load_csv

View File

@ -11,9 +11,9 @@ from typing import Any
import pandas as pd
from graphrag.index.config import PipelineInputConfig
from graphrag.index.progress import ProgressReporter
from graphrag.index.storage import PipelineStorage
from graphrag.index.utils import gen_md5_hash
from graphrag.logging import ProgressReporter
DEFAULT_FILE_PATTERN = re.compile(
r".*[\\/](?P<source>[^\\/]+)[\\/](?P<year>\d{4})-(?P<month>\d{2})-(?P<day>\d{2})_(?P<author>[^_]+)_\d+\.txt"

View File

@ -1,18 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Progress-reporting components."""
from .types import (
NullProgressReporter,
PrintProgressReporter,
ProgressReporter,
ReporterType,
)
__all__ = [
"NullProgressReporter",
"PrintProgressReporter",
"ProgressReporter",
"ReporterType",
]

View File

@ -1,141 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Types for status reporting."""
from abc import ABC, abstractmethod
from enum import Enum
from datashaper import Progress
class ReporterType(Enum):
"""The type of reporter to use."""
RICH = "rich"
PRINT = "print"
NONE = "none"
def __str__(self):
"""Return the string representation of the enum value."""
return self.value
class ProgressReporter(ABC):
"""
Abstract base class for progress reporters.
This is used to report workflow processing progress via mechanisms like progress-bars.
"""
@abstractmethod
def __call__(self, update: Progress):
"""Update progress."""
@abstractmethod
def dispose(self):
"""Dispose of the progress reporter."""
@abstractmethod
def child(self, prefix: str, transient=True) -> "ProgressReporter":
"""Create a child progress bar."""
@abstractmethod
def force_refresh(self) -> None:
"""Force a refresh."""
@abstractmethod
def stop(self) -> None:
"""Stop the progress reporter."""
@abstractmethod
def error(self, message: str) -> None:
"""Report an error."""
@abstractmethod
def warning(self, message: str) -> None:
"""Report a warning."""
@abstractmethod
def info(self, message: str) -> None:
"""Report information."""
@abstractmethod
def success(self, message: str) -> None:
"""Report success."""
class NullProgressReporter(ProgressReporter):
"""A progress reporter that does nothing."""
def __call__(self, update: Progress) -> None:
"""Update progress."""
def dispose(self) -> None:
"""Dispose of the progress reporter."""
def child(self, prefix: str, transient: bool = True) -> ProgressReporter:
"""Create a child progress bar."""
return self
def force_refresh(self) -> None:
"""Force a refresh."""
def stop(self) -> None:
"""Stop the progress reporter."""
def error(self, message: str) -> None:
"""Report an error."""
def warning(self, message: str) -> None:
"""Report a warning."""
def info(self, message: str) -> None:
"""Report information."""
def success(self, message: str) -> None:
"""Report success."""
class PrintProgressReporter(ProgressReporter):
"""A progress reporter that does nothing."""
prefix: str
def __init__(self, prefix: str):
"""Create a new progress reporter."""
self.prefix = prefix
print(f"\n{self.prefix}", end="") # noqa T201
def __call__(self, update: Progress) -> None:
"""Update progress."""
print(".", end="") # noqa T201
def dispose(self) -> None:
"""Dispose of the progress reporter."""
def child(self, prefix: str, transient: bool = True) -> "ProgressReporter":
"""Create a child progress bar."""
return PrintProgressReporter(prefix)
def stop(self) -> None:
"""Stop the progress reporter."""
def force_refresh(self) -> None:
"""Force a refresh."""
def error(self, message: str) -> None:
"""Report an error."""
print(f"\n{self.prefix}ERROR: {message}") # noqa T201
def warning(self, message: str) -> None:
"""Report a warning."""
print(f"\n{self.prefix}WARNING: {message}") # noqa T201
def info(self, message: str) -> None:
"""Report information."""
print(f"\n{self.prefix}INFO: {message}") # noqa T201
def success(self, message: str) -> None:
"""Report success."""
print(f"\n{self.prefix}SUCCESS: {message}") # noqa T201

View File

@ -1,18 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Reporting utilities and implementations for the indexing engine."""
from .blob_workflow_callbacks import BlobWorkflowCallbacks
from .console_workflow_callbacks import ConsoleWorkflowCallbacks
from .file_workflow_callbacks import FileWorkflowCallbacks
from .load_pipeline_reporter import load_pipeline_reporter
from .progress_workflow_callbacks import ProgressWorkflowCallbacks
__all__ = [
"BlobWorkflowCallbacks",
"ConsoleWorkflowCallbacks",
"FileWorkflowCallbacks",
"ProgressWorkflowCallbacks",
"load_pipeline_reporter",
]

View File

@ -13,6 +13,7 @@ from typing import cast
import pandas as pd
from datashaper import WorkflowCallbacks
from graphrag.callbacks.console_workflow_callbacks import ConsoleWorkflowCallbacks
from graphrag.index.cache import PipelineCache
from graphrag.index.config import (
PipelineConfig,
@ -21,10 +22,6 @@ from graphrag.index.config import (
)
from graphrag.index.emit import TableEmitterType, create_table_emitters
from graphrag.index.load_pipeline_config import load_pipeline_config
from graphrag.index.progress import NullProgressReporter, ProgressReporter
from graphrag.index.reporting import (
ConsoleWorkflowCallbacks,
)
from graphrag.index.run.cache import _create_cache
from graphrag.index.run.postprocess import (
_create_postprocess_steps,
@ -52,6 +49,10 @@ from graphrag.index.workflows import (
WorkflowDefinitions,
load_workflows,
)
from graphrag.logging import (
NullProgressReporter,
ProgressReporter,
)
from graphrag.utils.storage import _create_storage
log = logging.getLogger(__name__)

View File

@ -12,6 +12,7 @@ from datashaper import (
WorkflowCallbacks,
)
from graphrag.callbacks.factories import create_pipeline_reporter
from graphrag.index.cache.memory_pipeline_cache import InMemoryCache
from graphrag.index.cache.pipeline_cache import PipelineCache
from graphrag.index.config.cache import (
@ -31,10 +32,9 @@ from graphrag.index.config.storage import (
)
from graphrag.index.context import PipelineRunContext, PipelineRunStats
from graphrag.index.input import load_input
from graphrag.index.progress.types import ProgressReporter
from graphrag.index.reporting import load_pipeline_reporter
from graphrag.index.storage.memory_pipeline_storage import MemoryPipelineStorage
from graphrag.index.storage.typing import PipelineStorage
from graphrag.logging import ProgressReporter
log = logging.getLogger(__name__)
@ -43,7 +43,7 @@ def _create_reporter(
config: PipelineReportingConfigTypes | None, root_dir: str
) -> WorkflowCallbacks | None:
"""Create the reporter for the pipeline."""
return load_pipeline_reporter(config, root_dir) if config else None
return create_pipeline_reporter(config, root_dir) if config else None
async def _create_input(

View File

@ -15,15 +15,13 @@ from datashaper import (
WorkflowCallbacksManager,
)
from graphrag.callbacks.progress_workflow_callbacks import ProgressWorkflowCallbacks
from graphrag.index.context import PipelineRunContext
from graphrag.index.emit.table_emitter import TableEmitter
from graphrag.index.progress.types import ProgressReporter
from graphrag.index.reporting.progress_workflow_callbacks import (
ProgressWorkflowCallbacks,
)
from graphrag.index.run.profiling import _write_workflow_stats
from graphrag.index.storage.typing import PipelineStorage
from graphrag.index.typing import PipelineRunResult
from graphrag.logging import ProgressReporter
from graphrag.utils.storage import _load_table_from_storage
log = logging.getLogger(__name__)

View File

@ -13,7 +13,7 @@ from azure.identity import DefaultAzureCredential
from azure.storage.blob import BlobServiceClient
from datashaper import Progress
from graphrag.index.progress import ProgressReporter
from graphrag.logging import ProgressReporter
from .typing import PipelineStorage

View File

@ -16,7 +16,7 @@ from aiofiles.os import remove
from aiofiles.ospath import exists
from datashaper import Progress
from graphrag.index.progress import ProgressReporter
from graphrag.logging import ProgressReporter
from .typing import PipelineStorage

View File

@ -8,7 +8,7 @@ from abc import ABCMeta, abstractmethod
from collections.abc import Iterator
from typing import Any
from graphrag.index.progress import ProgressReporter
from graphrag.logging import ProgressReporter
class PipelineStorage(metaclass=ABCMeta):

View File

@ -10,9 +10,7 @@ from datashaper import NoopVerbCallbacks
from graphrag.config.models import GraphRagConfig
from graphrag.index.llm import load_llm, load_llm_embeddings
from graphrag.index.progress import (
ProgressReporter,
)
from graphrag.logging import ProgressReporter
def validate_config_names(

View File

@ -0,0 +1,27 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Logging utilities and implementations."""
from .console import ConsoleReporter
from .factories import create_progress_reporter
from .null_progress import NullProgressReporter
from .print_progress import PrintProgressReporter
from .rich_progress import RichProgressReporter
from .types import (
ProgressReporter,
ReporterType,
StatusLogger,
)
__all__ = [
# Reporters, status loggers, reporter types, and the reporter factory
"ConsoleReporter",
"NullProgressReporter",
"PrintProgressReporter",
"ProgressReporter",
"ReporterType",
"RichProgressReporter",
"StatusLogger",
"create_progress_reporter",
]

View File

@ -1,29 +1,14 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Status Reporter for orchestration."""
"""Console Reporter."""
from abc import ABCMeta, abstractmethod
from typing import Any
class StatusReporter(metaclass=ABCMeta):
"""Provides a way to report status updates from the pipeline."""
@abstractmethod
def error(self, message: str, details: dict[str, Any] | None = None):
"""Report an error."""
@abstractmethod
def warning(self, message: str, details: dict[str, Any] | None = None):
"""Report a warning."""
@abstractmethod
def log(self, message: str, details: dict[str, Any] | None = None):
"""Report a log."""
from .types import StatusLogger
class ConsoleStatusReporter(StatusReporter):
class ConsoleReporter(StatusLogger):
"""A reporter that writes to a console."""
def error(self, message: str, details: dict[str, Any] | None = None):

View File

@ -1,18 +1,18 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Load a progress reporter."""
"""Factory functions for creating loggers."""
from .rich import RichProgressReporter
from .null_progress import NullProgressReporter
from .print_progress import PrintProgressReporter
from .rich_progress import RichProgressReporter
from .types import (
NullProgressReporter,
PrintProgressReporter,
ProgressReporter,
ReporterType,
)
def load_progress_reporter(
def create_progress_reporter(
reporter_type: ReporterType = ReporterType.NONE,
) -> ProgressReporter:
"""Load a progress reporter.

View File

@ -0,0 +1,38 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Null Progress Reporter."""
from .types import Progress, ProgressReporter
class NullProgressReporter(ProgressReporter):
    """A progress reporter whose every operation is a no-op."""

    def __call__(self, update: Progress) -> None:
        """Accept a progress update and discard it."""
        return

    def dispose(self) -> None:
        """Dispose of the progress reporter (nothing to release)."""
        return

    def child(self, prefix: str, transient: bool = True) -> ProgressReporter:
        """Return this same reporter; ``prefix`` and ``transient`` are ignored."""
        return self

    def force_refresh(self) -> None:
        """Force a refresh (no-op)."""
        return

    def stop(self) -> None:
        """Stop the progress reporter (no-op)."""
        return

    def error(self, message: str) -> None:
        """Discard an error message."""
        return

    def warning(self, message: str) -> None:
        """Discard a warning message."""
        return

    def info(self, message: str) -> None:
        """Discard an informational message."""
        return

    def success(self, message: str) -> None:
        """Discard a success message."""
        return

View File

@ -0,0 +1,50 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Print Progress Reporter."""
from .types import Progress, ProgressReporter
class PrintProgressReporter(ProgressReporter):
    """A progress reporter that prints progress dots and status messages via ``print``."""

    prefix: str

    def __init__(self, prefix: str):
        """Create a new progress reporter.

        The prefix is printed once (after a newline, without a trailing
        newline) on construction and is prepended to every status message.
        """
        self.prefix = prefix
        print(f"\n{self.prefix}", end="")  # noqa: T201

    def __call__(self, update: Progress) -> None:
        """Print one dot per progress update; ``update`` itself is not inspected."""
        print(".", end="")  # noqa: T201

    def dispose(self) -> None:
        """Dispose of the progress reporter (no-op)."""

    def child(self, prefix: str, transient: bool = True) -> "ProgressReporter":
        """Create a child reporter with its own prefix.

        The child does not inherit this reporter's prefix, and ``transient``
        is ignored.
        """
        return PrintProgressReporter(prefix)

    def stop(self) -> None:
        """Stop the progress reporter (no-op)."""

    def force_refresh(self) -> None:
        """Force a refresh (no-op)."""

    def error(self, message: str) -> None:
        """Print an error message on a new line, tagged ``ERROR``."""
        print(f"\n{self.prefix}ERROR: {message}")  # noqa: T201

    def warning(self, message: str) -> None:
        """Print a warning message on a new line, tagged ``WARNING``."""
        print(f"\n{self.prefix}WARNING: {message}")  # noqa: T201

    def info(self, message: str) -> None:
        """Print an informational message on a new line, tagged ``INFO``."""
        print(f"\n{self.prefix}INFO: {message}")  # noqa: T201

    def success(self, message: str) -> None:
        """Print a success message on a new line, tagged ``SUCCESS``."""
        print(f"\n{self.prefix}SUCCESS: {message}")  # noqa: T201

82
graphrag/logging/types.py Normal file
View File

@ -0,0 +1,82 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Types for status reporting."""
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any
from datashaper import Progress
class ReporterType(Enum):
    """The kinds of reporter that can be selected."""

    RICH = "rich"
    PRINT = "print"
    NONE = "none"

    def __str__(self):
        """Render the member as its raw string value instead of ``ReporterType.NAME``."""
        raw_value = self.value
        return raw_value
class StatusLogger(ABC):
"""Provides a way to report status updates from the pipeline."""
@abstractmethod
def error(self, message: str, details: dict[str, Any] | None = None):
"""Report an error."""
@abstractmethod
def warning(self, message: str, details: dict[str, Any] | None = None):
"""Report a warning."""
@abstractmethod
def log(self, message: str, details: dict[str, Any] | None = None):
"""Report a log."""
class ProgressReporter(ABC):
    """Abstract base class for progress reporters.

    Implementations report workflow processing progress through mechanisms
    such as progress bars, and also expose error/warning/info/success
    message channels.
    """

    @abstractmethod
    def __call__(self, update: Progress):
        """Receive a progress update."""

    @abstractmethod
    def dispose(self):
        """Dispose of the progress reporter."""

    @abstractmethod
    def child(self, prefix: str, transient=True) -> "ProgressReporter":
        """Create a child progress reporter labelled with ``prefix``."""

    @abstractmethod
    def force_refresh(self) -> None:
        """Force a refresh of the reporter's output."""

    @abstractmethod
    def stop(self) -> None:
        """Stop the progress reporter."""

    @abstractmethod
    def error(self, message: str) -> None:
        """Report an error message."""

    @abstractmethod
    def warning(self, message: str) -> None:
        """Report a warning message."""

    @abstractmethod
    def info(self, message: str) -> None:
        """Report an informational message."""

    @abstractmethod
    def success(self, message: str) -> None:
        """Report a success message."""

View File

@ -6,9 +6,9 @@
import argparse
import asyncio
from graphrag.api import DocSelectionType
from graphrag.utils.cli import dir_exist, file_exist
from .api import DocSelectionType
from .cli import prompt_tune
from .generator import MAX_TOKEN_COUNT
from .loader import MIN_CHUNK_SIZE

View File

@ -5,21 +5,20 @@
from pathlib import Path
import graphrag.api as api
from graphrag.config import load_config
from graphrag.index.progress import PrintProgressReporter
from graphrag.logging import PrintProgressReporter
from . import api
from .generator.community_report_summarization import COMMUNITY_SUMMARIZATION_FILENAME
from .generator.entity_extraction_prompt import ENTITY_EXTRACTION_FILENAME
from .generator.entity_summarization_prompt import ENTITY_SUMMARIZATION_FILENAME
from .types import DocSelectionType
async def prompt_tune(
config: str,
root: str,
domain: str,
selection_method: DocSelectionType,
selection_method: api.DocSelectionType,
limit: int,
max_tokens: int,
chunk_size: int,

View File

@ -12,8 +12,8 @@ from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.input import load_input
from graphrag.index.llm import load_llm_embeddings
from graphrag.index.operations.chunk_text import chunk_text
from graphrag.index.progress.types import ProgressReporter
from graphrag.llm.types.llm_types import EmbeddingLLM
from graphrag.logging import ProgressReporter
from graphrag.prompt_tune.types import DocSelectionType
MIN_CHUNK_OVERLAP = 0

View File

@ -9,13 +9,12 @@ from pathlib import Path
import pandas as pd
import graphrag.api as api
from graphrag.config import GraphRagConfig, load_config, resolve_paths
from graphrag.index.create_pipeline_config import create_pipeline_config
from graphrag.index.progress import PrintProgressReporter
from graphrag.logging import PrintProgressReporter
from graphrag.utils.storage import _create_storage, _load_table_from_storage
from . import api
reporter = PrintProgressReporter("")

View File

@ -7,16 +7,7 @@ from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Generator
from typing import Any
class BaseLLMCallback:
"""Base class for LLM callbacks."""
def __init__(self):
self.response = []
def on_llm_new_token(self, token: str):
"""Handle when a new token is generated."""
self.response.append(token)
from graphrag.callbacks.llm_callbacks import BaseLLMCallback
class BaseLLM(ABC):

View File

@ -8,9 +8,9 @@ from collections.abc import Callable
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
from graphrag.logging import ConsoleReporter, StatusLogger
from graphrag.query.llm.base import BaseTextEmbedding
from graphrag.query.llm.oai.typing import OpenaiApiType
from graphrag.query.progress import ConsoleStatusReporter, StatusReporter
class BaseOpenAILLM(ABC):
@ -87,7 +87,7 @@ class BaseOpenAILLM(ABC):
class OpenAILLMImpl(BaseOpenAILLM):
"""Orchestration OpenAI LLM Implementation."""
_reporter: StatusReporter = ConsoleStatusReporter()
_reporter: StatusLogger = ConsoleReporter()
def __init__(
self,
@ -100,7 +100,7 @@ class OpenAILLMImpl(BaseOpenAILLM):
organization: str | None = None,
max_retries: int = 10,
request_timeout: float = 180.0,
reporter: StatusReporter | None = None,
reporter: StatusLogger | None = None,
):
self.api_key = api_key
self.azure_ad_token_provider = azure_ad_token_provider
@ -111,7 +111,7 @@ class OpenAILLMImpl(BaseOpenAILLM):
self.organization = organization
self.max_retries = max_retries
self.request_timeout = request_timeout
self.reporter = reporter or ConsoleStatusReporter()
self.reporter = reporter or ConsoleReporter()
try:
# Create OpenAI sync and async clients
@ -181,7 +181,7 @@ class OpenAILLMImpl(BaseOpenAILLM):
class OpenAITextEmbeddingImpl(BaseTextEmbedding):
"""Orchestration OpenAI Text Embedding Implementation."""
_reporter: StatusReporter | None = None
_reporter: StatusLogger | None = None
def _create_openai_client(self, api_type: OpenaiApiType):
"""Create a new synchronous and asynchronous OpenAI client instance."""

View File

@ -15,13 +15,13 @@ from tenacity import (
wait_exponential_jitter,
)
from graphrag.logging import StatusLogger
from graphrag.query.llm.base import BaseLLM, BaseLLMCallback
from graphrag.query.llm.oai.base import OpenAILLMImpl
from graphrag.query.llm.oai.typing import (
OPENAI_RETRY_ERROR_TYPES,
OpenaiApiType,
)
from graphrag.query.progress import StatusReporter
_MODEL_REQUIRED_MSG = "model is required"
@ -42,7 +42,7 @@ class ChatOpenAI(BaseLLM, OpenAILLMImpl):
max_retries: int = 10,
request_timeout: float = 180.0,
retry_error_types: tuple[type[BaseException]] = OPENAI_RETRY_ERROR_TYPES, # type: ignore
reporter: StatusReporter | None = None,
reporter: StatusLogger | None = None,
):
OpenAILLMImpl.__init__(
self=self,

View File

@ -18,6 +18,7 @@ from tenacity import (
wait_exponential_jitter,
)
from graphrag.logging import StatusLogger
from graphrag.query.llm.base import BaseTextEmbedding
from graphrag.query.llm.oai.base import OpenAILLMImpl
from graphrag.query.llm.oai.typing import (
@ -25,7 +26,6 @@ from graphrag.query.llm.oai.typing import (
OpenaiApiType,
)
from graphrag.query.llm.text_utils import chunk_text
from graphrag.query.progress import StatusReporter
class OpenAIEmbedding(BaseTextEmbedding, OpenAILLMImpl):
@ -46,7 +46,7 @@ class OpenAIEmbedding(BaseTextEmbedding, OpenAILLMImpl):
max_retries: int = 10,
request_timeout: float = 180.0,
retry_error_types: tuple[type[BaseException]] = OPENAI_RETRY_ERROR_TYPES, # type: ignore
reporter: StatusReporter | None = None,
reporter: StatusLogger | None = None,
):
OpenAILLMImpl.__init__(
self=self,

View File

@ -14,6 +14,7 @@ from typing import Any
import pandas as pd
import tiktoken
from graphrag.callbacks.global_search_callbacks import GlobalSearchLLMCallback
from graphrag.llm.openai.utils import try_parse_json_object
from graphrag.query.context_builder.builders import GlobalContextBuilder
from graphrag.query.context_builder.conversation_history import (
@ -22,9 +23,6 @@ from graphrag.query.context_builder.conversation_history import (
from graphrag.query.llm.base import BaseLLM
from graphrag.query.llm.text_utils import num_tokens
from graphrag.query.structured_search.base import BaseSearch, SearchResult
from graphrag.query.structured_search.global_search.callbacks import (
GlobalSearchLLMCallback,
)
from graphrag.query.structured_search.global_search.map_system_prompt import (
MAP_SYSTEM_PROMPT,
)

View File

@ -1,7 +1,7 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""A package containing vector-storage implementations."""
"""A module containing vector storage implementations."""
from .azure_ai_search import AzureAISearch
from .base import BaseVectorStore, VectorStoreDocument, VectorStoreSearchResult

View File

@ -108,9 +108,9 @@ requires = ["poetry-core>=1.0.0", "poetry-dynamic-versioning>=1.0.0,<2.0.0"]
build-backend = "poetry_dynamic_versioning.backend"
[tool.poe.tasks]
_sort_imports = "ruff check --select I --fix . --preview"
_format_code = "ruff format . --preview"
_ruff_check = 'ruff check . --preview'
_sort_imports = "ruff check --select I --fix ."
_format_code = "ruff format ."
_ruff_check = 'ruff check .'
_pyright = "pyright"
_convert_local_search_nb = 'jupyter nbconvert --output-dir=docsite/posts/query/notebooks/ --output="{notebook_name}_nb" --template=docsite/nbdocsite_template --to markdown examples_notebooks/local_search.ipynb'
_convert_global_search_nb = 'jupyter nbconvert --output-dir=docsite/posts/query/notebooks/ --output="{notebook_name}_nb" --template=docsite/nbdocsite_template --to markdown examples_notebooks/global_search.ipynb'
@ -119,9 +119,9 @@ _semversioner_changelog = "semversioner changelog > CHANGELOG.md"
_semversioner_update_toml_version = "update-toml update --path tool.poetry.version --value $(poetry run semversioner current-version)"
semversioner_add = "semversioner add-change"
coverage_report = 'coverage report --omit "**/tests/**" --show-missing'
check_format = 'ruff format . --check --preview'
fix = "ruff --preview check --fix ."
fix_unsafe = "ruff check --preview --fix --unsafe-fixes ."
check_format = 'ruff format . --check'
fix = "ruff check --fix ."
fix_unsafe = "ruff check --fix --unsafe-fixes ."
_test_all = "coverage run -m pytest ./tests"
test_unit = "pytest ./tests/unit"
@ -164,10 +164,12 @@ target-version = "py310"
extend-include = ["*.ipynb"]
[tool.ruff.format]
preview = true
docstring-code-format = true
docstring-code-line-length = 20
[tool.ruff.lint]
preview = true
select = [
"E4",
"E7",