Mirror of https://github.com/microsoft/graphrag.git, synced 2025-06-26 23:19:58 +00:00
Index API (#953)
* Initial Index API
  - Implement main API entry point: build_index
  - Rely on GraphRagConfig instead of PipelineConfig; this unifies the API signature with the prompt_tune and query API entry points
  - Derive cache settings, config, and resume behavior from the config and other arguments to simplify/reduce the arguments to build_index
  - Add preflight config file validations
  - Add semver change
* Fix smoke tests
* Use asyncio
* Add e2e artifacts in GH actions
* Remove unnecessary E2E test, and add skip_validations flag to cli
* Nicer imports
* Reorganize API functions
* Add license headers and module docstrings
* Fix ignored ruff rule

---------

Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
This commit is contained in:
  parent 5a781dd234
  commit 6b4de3d841
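For orientation, the caller-side shape of the new API introduced here, as a minimal sketch (the ./ragtest path and run id are illustrative; the names come from graphrag/config/config_file_loader.py and graphrag/index/api.py below):

    import asyncio

    from graphrag.config.config_file_loader import load_config_from_file
    from graphrag.index.api import build_index

    # Load a GraphRagConfig from a hypothetical project directory.
    config = load_config_from_file("./ragtest/settings.yaml")

    # Cache settings and resume behavior are derived from the config,
    # keeping the signature aligned with the prompt_tune and query APIs.
    results = asyncio.run(build_index(config, run_id="20240101-000000", memory_profile=False))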
.github/workflows/python-smoke-tests.yml (vendored, 5 changes)
@@ -102,8 +102,3 @@ jobs:
         with:
           name: smoke-test-artifacts-${{ matrix.python-version }}-${{ matrix.poetry-version }}-${{ runner.os }}
           path: tests/fixtures/*/output
-
-      - name: E2E Test
-        if: steps.changes.outputs.python == 'true'
-        run: |
-          ./scripts/e2e-test.sh
@@ -0,0 +1,4 @@
+{
+    "type": "minor",
+    "description": "Implement Index API"
+}
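The change file above follows the semversioner convention this repo uses for release notes. Such a file is typically generated with the semversioner CLI rather than written by hand (a sketch, assuming semversioner is installed; this command is not part of the diff):

    semversioner add-change --type minor --description "Implement Index API"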
graphrag/config/config_file_loader.py (new file, 184 lines)
@@ -0,0 +1,184 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Load a GraphRagConfiguration from a file."""

import json
from abc import ABC, abstractmethod
from pathlib import Path

import yaml

from . import create_graphrag_config
from .models.graph_rag_config import GraphRagConfig

_default_config_files = ["settings.yaml", "settings.yml", "settings.json"]


def resolve_config_path_with_root(root: str | Path) -> Path:
    """Resolve the config path from the given root directory.

    Parameters
    ----------
    root : str | Path
        The path to the root directory containing the config file.
        Searches for a default config file (settings.{yaml,yml,json}).

    Returns
    -------
    Path
        The resolved config file path.

    Raises
    ------
    FileNotFoundError
        If the config file is not found or cannot be resolved for the directory.
    """
    root = Path(root)

    if not root.is_dir():
        msg = f"Invalid config path: {root} is not a directory"
        raise FileNotFoundError(msg)

    for file in _default_config_files:
        if (root / file).is_file():
            return root / file

    msg = f"Unable to resolve config file for parent directory: {root}"
    raise FileNotFoundError(msg)


class ConfigFileLoader(ABC):
    """Base class for loading a configuration from a file."""

    @abstractmethod
    def load_config(self, config_path: str | Path) -> GraphRagConfig:
        """Load configuration from a file."""
        raise NotImplementedError


class ConfigYamlLoader(ConfigFileLoader):
    """Load a configuration from a yaml file."""

    def load_config(self, config_path: str | Path) -> GraphRagConfig:
        """Load a configuration from a yaml file.

        Parameters
        ----------
        config_path : str | Path
            The path to the yaml file to load.

        Returns
        -------
        GraphRagConfig
            The loaded configuration.

        Raises
        ------
        ValueError
            If the file extension is not .yaml or .yml.
        FileNotFoundError
            If the config file is not found.
        """
        config_path = Path(config_path)
        if config_path.suffix not in [".yaml", ".yml"]:
            msg = f"Invalid file extension for loading yaml config from: {config_path!s}. Expected .yaml or .yml"
            raise ValueError(msg)
        root_dir = str(config_path.parent)
        if not config_path.is_file():
            msg = f"Config file not found: {config_path}"
            raise FileNotFoundError(msg)
        with config_path.open("rb") as file:
            data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict"))
            return create_graphrag_config(data, root_dir)


class ConfigJsonLoader(ConfigFileLoader):
    """Load a configuration from a json file."""

    def load_config(self, config_path: str | Path) -> GraphRagConfig:
        """Load a configuration from a json file.

        Parameters
        ----------
        config_path : str | Path
            The path to the json file to load.

        Returns
        -------
        GraphRagConfig
            The loaded configuration.

        Raises
        ------
        ValueError
            If the file extension is not .json.
        FileNotFoundError
            If the config file is not found.
        """
        config_path = Path(config_path)
        root_dir = str(config_path.parent)
        if config_path.suffix != ".json":
            msg = f"Invalid file extension for loading json config from: {config_path!s}. Expected .json"
            raise ValueError(msg)
        if not config_path.is_file():
            msg = f"Config file not found: {config_path}"
            raise FileNotFoundError(msg)
        with config_path.open("rb") as file:
            data = json.loads(file.read().decode(encoding="utf-8", errors="strict"))
            return create_graphrag_config(data, root_dir)


def get_config_file_loader(config_path: str | Path) -> ConfigFileLoader:
    """Config File Loader Factory.

    Parameters
    ----------
    config_path : str | Path
        The path to the config file.

    Returns
    -------
    ConfigFileLoader
        The config file loader for the provided config file.

    Raises
    ------
    ValueError
        If the config file extension is not supported.
    """
    config_path = Path(config_path)
    ext = config_path.suffix
    match ext:
        case ".yaml" | ".yml":
            return ConfigYamlLoader()
        case ".json":
            return ConfigJsonLoader()
        case _:
            msg = f"Unsupported config file extension: {ext}"
            raise ValueError(msg)


def load_config_from_file(config_path: str | Path) -> GraphRagConfig:
    """Load a configuration from a file.

    Parameters
    ----------
    config_path : str | Path
        The path to the configuration file.
        Supports .yaml, .yml, and .json config files.

    Returns
    -------
    GraphRagConfig
        The loaded configuration.

    Raises
    ------
    ValueError
        If the file extension is not supported.
    FileNotFoundError
        If the config file is not found.
    """
    loader = get_config_file_loader(config_path)
    return loader.load_config(config_path)
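A short usage sketch of the loaders above (the ./ragtest paths are illustrative, not from the diff):

    from graphrag.config.config_file_loader import (
        get_config_file_loader,
        load_config_from_file,
        resolve_config_path_with_root,
    )

    # Extension-based dispatch: .yaml/.yml -> ConfigYamlLoader, .json -> ConfigJsonLoader.
    loader = get_config_file_loader("./ragtest/settings.yaml")
    config = loader.load_config("./ragtest/settings.yaml")

    # Equivalent one-call form; root resolution searches for settings.yaml,
    # settings.yml, then settings.json, and raises FileNotFoundError otherwise.
    config = load_config_from_file(resolve_config_path_with_root("./ragtest"))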
graphrag/config/logging.py (new file, 65 lines)
@@ -0,0 +1,65 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Logging utilities. A unified way for enabling logging."""

import logging
from pathlib import Path

from .enums import ReportingType
from .models.graph_rag_config import GraphRagConfig
from .resolve_timestamp_path import resolve_timestamp_path


def enable_logging(log_filepath: str | Path, verbose: bool = False) -> None:
    """Enable logging to a file.

    Parameters
    ----------
    log_filepath : str | Path
        The path to the log file.
    verbose : bool, default=False
        Whether to log debug messages.
    """
    log_filepath = Path(log_filepath)
    log_filepath.parent.mkdir(parents=True, exist_ok=True)
    log_filepath.touch(exist_ok=True)

    logging.basicConfig(
        filename=log_filepath,
        filemode="a",
        format="%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s",
        datefmt="%H:%M:%S",
        level=logging.DEBUG if verbose else logging.INFO,
    )


def enable_logging_with_config(
    config: GraphRagConfig, timestamp_value: str, verbose: bool = False
) -> tuple[bool, str]:
    """Enable logging to a file based on the config.

    Parameters
    ----------
    config : GraphRagConfig
        The configuration.
    timestamp_value : str
        The timestamp value representing the directory to place the log files.
    verbose : bool, default=False
        Whether to log debug messages.

    Returns
    -------
    tuple[bool, str]
        A tuple of a boolean indicating if logging was enabled and the path to the log file.
        (False, "") if logging was not enabled.
        (True, str) if logging was enabled.
    """
    if config.reporting.type == ReportingType.file:
        log_path = resolve_timestamp_path(
            Path(config.root_dir) / config.reporting.base_dir / "indexing-engine.log",
            timestamp_value,
        )
        enable_logging(log_path, verbose)
        return (True, str(log_path))
    return (False, "")
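Callers are expected to use this the way the cli rewiring later in this diff does; a minimal sketch, assuming an already-loaded GraphRagConfig named config and an illustrative run id:

    from graphrag.config.logging import enable_logging_with_config

    enabled, log_path = enable_logging_with_config(config, timestamp_value="20240101-000000")
    if enabled:
        # log_path points at <root_dir>/<reporting.base_dir>/indexing-engine.log
        print(f"Logging enabled at {log_path}")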
graphrag/config/resolve_timestamp_path.py (new file, 115 lines)
@@ -0,0 +1,115 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Resolve timestamp variables in a path."""

import re
from pathlib import Path
from string import Template


def _resolve_timestamp_path_with_value(path: str | Path, timestamp_value: str) -> Path:
    """Resolve the timestamp in the path with the given timestamp value.

    Parameters
    ----------
    path : str | Path
        The path containing ${timestamp} variables to resolve.
    timestamp_value : str
        The timestamp value used to resolve the path.

    Returns
    -------
    Path
        The path with ${timestamp} variables resolved to the provided timestamp value.
    """
    template = Template(str(path))
    resolved_path = template.substitute(timestamp=timestamp_value)
    return Path(resolved_path)


def _resolve_timestamp_path_with_dir(
    path: str | Path, pattern: re.Pattern[str]
) -> Path:
    """Resolve the timestamp in the path with the latest available timestamp directory value.

    Parameters
    ----------
    path : str | Path
        The path containing ${timestamp} variables to resolve.
    pattern : re.Pattern[str]
        The pattern to use to match the timestamp directories.

    Returns
    -------
    Path
        The path with ${timestamp} variables resolved to the latest available timestamp directory value.

    Raises
    ------
    ValueError
        If the parent directory expecting to contain timestamp directories does not exist or is not a directory.
        Or if no timestamp directories are found in the parent directory that match the pattern.
    """
    path = Path(path)
    path_parts = path.parts
    parent_dir = Path(path_parts[0])
    found_timestamp_pattern = False
    for part in path_parts[1:]:
        if part.lower() == "${timestamp}":
            found_timestamp_pattern = True
            break
        parent_dir = parent_dir / part

    # Path not using timestamp layout.
    if not found_timestamp_pattern:
        return path

    if not parent_dir.exists() or not parent_dir.is_dir():
        msg = f"Parent directory {parent_dir} does not exist or is not a directory."
        raise ValueError(msg)

    timestamp_dirs = [
        d for d in parent_dir.iterdir() if d.is_dir() and pattern.match(d.name)
    ]
    timestamp_dirs.sort(key=lambda d: d.name, reverse=True)
    if len(timestamp_dirs) == 0:
        msg = f"No timestamp directories found in {parent_dir} that match {pattern.pattern}."
        raise ValueError(msg)
    return _resolve_timestamp_path_with_value(path, timestamp_dirs[0].name)


def resolve_timestamp_path(
    path: str | Path,
    pattern_or_timestamp_value: re.Pattern[str] | str = re.compile(r"^\d{8}-\d{6}$"),
) -> Path:
    r"""Timestamp path resolver.

    Resolve the timestamp in the path with the given timestamp value or
    with the latest available timestamp directory matching the given pattern.

    Parameters
    ----------
    path : str | Path
        The path containing ${timestamp} variables to resolve.
    pattern_or_timestamp_value : re.Pattern[str] | str, default=re.compile(r"^\d{8}-\d{6}$")
        The pattern to use to match the timestamp directories or the timestamp value to use.
        If a string is provided, the path will be resolved with the given string value.
        Otherwise, the path will be resolved with the latest available timestamp directory
        that matches the given pattern.

    Returns
    -------
    Path
        The path with ${timestamp} variables resolved to the provided timestamp value or
        the latest available timestamp directory.

    Raises
    ------
    ValueError
        If the parent directory expecting to contain timestamp directories does not exist or is not a directory.
        Or if no timestamp directories are found in the parent directory that match the pattern.
    """
    if isinstance(pattern_or_timestamp_value, str):
        return _resolve_timestamp_path_with_value(path, pattern_or_timestamp_value)
    return _resolve_timestamp_path_with_dir(path, pattern_or_timestamp_value)
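The two resolution modes, sketched with illustrative directory names:

    from graphrag.config.resolve_timestamp_path import resolve_timestamp_path

    # Explicit value: ${timestamp} is substituted directly, no filesystem checks.
    resolve_timestamp_path("output/${timestamp}/artifacts", "20240101-000000")
    # -> Path("output/20240101-000000/artifacts")

    # Pattern mode (the default): picks the lexicographically latest directory
    # under output/ whose name matches ^\d{8}-\d{6}$; raises ValueError if none match.
    resolve_timestamp_path("output/${timestamp}/artifacts")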
@@ -68,6 +68,11 @@ if __name__ == "__main__":
         help="Overlay default configuration values on a provided configuration file (--config).",
         action="store_true",
     )
+    parser.add_argument(
+        "--skip-validations",
+        help="Skip any preflight validation. Useful when running no LLM steps.",
+        action="store_true",
+    )
     args = parser.parse_args()
 
     if args.overlay_defaults and not args.config:
@@ -85,5 +90,5 @@ if __name__ == "__main__":
         dryrun=args.dryrun or False,
         init=args.init or False,
         overlay_defaults=args.overlay_defaults or False,
         cli=True,
+        skip_validations=args.skip_validations or False,
     )
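With the flag wired into the argument parser, a run that skips the LLM preflight checks would look like this (a sketch; the ./ragtest root is illustrative):

    python -m graphrag.index --root ./ragtest --skip-validations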
graphrag/index/api.py (new file, 79 lines)
@@ -0,0 +1,79 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""
Indexing API for GraphRAG.

WARNING: This API is under development and may undergo changes in future releases.
Backwards compatibility is not guaranteed at this time.
"""

from graphrag.config.enums import CacheType
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.config.resolve_timestamp_path import resolve_timestamp_path

from .cache.noop_pipeline_cache import NoopPipelineCache
from .create_pipeline_config import create_pipeline_config
from .emit.types import TableEmitterType
from .progress import ProgressReporter
from .run import run_pipeline_with_config
from .typing import PipelineRunResult


async def build_index(
    config: GraphRagConfig,
    run_id: str,
    memory_profile: bool,
    progress_reporter: ProgressReporter | None = None,
    emit: list[str] | None = None,
) -> list[PipelineRunResult]:
    """Run the pipeline with the given configuration.

    Parameters
    ----------
    config : GraphRagConfig
        The configuration.
    run_id : str
        The run id. Creates an output directory with this name.
    memory_profile : bool
        Whether to enable memory profiling.
    progress_reporter : ProgressReporter | None, default=None
        The progress reporter.
    emit : list[str] | None, default=None
        The list of emitter types to emit.
        Accepted values: {"parquet", "csv"}.

    Returns
    -------
    list[PipelineRunResult]
        The list of pipeline run results.
    """
    # Treat the run as a resume if the storage timestamp path resolves without error.
    try:
        resolve_timestamp_path(config.storage.base_dir, run_id)
        resume = True
    except ValueError:
        resume = False
    pipeline_config = create_pipeline_config(config)
    # Use a no-op cache when caching is disabled in the config.
    pipeline_cache = (
        NoopPipelineCache() if config.cache.type == CacheType.none else None
    )
    outputs: list[PipelineRunResult] = []
    async for output in run_pipeline_with_config(
        pipeline_config,
        run_id=run_id,
        memory_profile=memory_profile,
        cache=pipeline_cache,
        progress_reporter=progress_reporter,
        emit=([TableEmitterType(e) for e in emit] if emit is not None else None),
        is_resume_run=resume,
    ):
        outputs.append(output)
        if progress_reporter:
            if output.errors and len(output.errors) > 0:
                progress_reporter.error(output.workflow)
            else:
                progress_reporter.success(output.workflow)
            progress_reporter.info(str(output.result))
    return outputs
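A sketch of driving build_index with a progress reporter and explicit emitters, mirroring the cli wiring below (config is an already-loaded GraphRagConfig; the run id is illustrative):

    import asyncio

    from graphrag.index.api import build_index
    from graphrag.index.progress.load_progress_reporter import load_progress_reporter

    reporter = load_progress_reporter("rich")
    outputs = asyncio.run(
        build_index(
            config,
            run_id="20240101-000000",
            memory_profile=False,
            progress_reporter=reporter,
            emit=["parquet"],  # accepted values per the docstring: "parquet", "csv"
        )
    )
    encountered_errors = any(o.errors for o in outputs)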
@@ -6,32 +6,28 @@
 import asyncio
 import json
 import logging
-import platform
 import sys
 import time
 import warnings
 from pathlib import Path
 
-from graphrag.config import (
-    create_graphrag_config,
-)
+from graphrag.config import create_graphrag_config
+from graphrag.config.config_file_loader import (
+    load_config_from_file,
+    resolve_config_path_with_root,
+)
-from graphrag.index import PipelineConfig, create_pipeline_config
-from graphrag.index.cache import NoopPipelineCache
-from graphrag.index.progress import (
-    NullProgressReporter,
-    PrintProgressReporter,
-    ProgressReporter,
-)
-from graphrag.index.progress.rich import RichProgressReporter
-from graphrag.index.run import run_pipeline_with_config
-from graphrag.index.validate_config import validate_config_names
+from graphrag.config.enums import CacheType
+from graphrag.config.logging import enable_logging_with_config
 
-from .emit import TableEmitterType
+from .api import build_index
 from .graph.extractors.claims.prompts import CLAIM_EXTRACTION_PROMPT
 from .graph.extractors.community_reports.prompts import COMMUNITY_REPORT_PROMPT
 from .graph.extractors.graph.prompts import GRAPH_EXTRACTION_PROMPT
 from .graph.extractors.summarize.prompts import SUMMARIZE_PROMPT
 from .init_content import INIT_DOTENV, INIT_YAML
+from .progress import ProgressReporter
+from .progress.load_progress_reporter import load_progress_reporter
+from .validate_config import validate_config_names
 
 # Ignore warnings from numba
 warnings.filterwarnings("ignore", message=".*NumbaDeprecationWarning.*")
@@ -39,7 +35,7 @@ warnings.filterwarnings("ignore", message=".*NumbaDeprecationWarning.*")
 log = logging.getLogger(__name__)
 
 
-def redact(input: dict) -> str:
+def _redact(input: dict) -> str:
     """Sanitize the config json."""
 
     # Redact any sensitive configuration
@@ -56,7 +52,7 @@ def _redact(input: dict) -> str:
             "organization",
         }:
             if value is not None:
-                result[key] = f"REDACTED, length {len(value)}"
+                result[key] = "==== REDACTED ===="
             elif isinstance(value, dict):
                 result[key] = redact_dict(value)
             elif isinstance(value, list):
@@ -69,6 +65,43 @@ def _redact(input: dict) -> str:
     return json.dumps(redacted_dict, indent=4)
 
 
+def _logger(reporter: ProgressReporter):
+    def info(msg: str, verbose: bool = False):
+        log.info(msg)
+        if verbose:
+            reporter.info(msg)
+
+    def error(msg: str, verbose: bool = False):
+        log.error(msg)
+        if verbose:
+            reporter.error(msg)
+
+    def success(msg: str, verbose: bool = False):
+        log.info(msg)
+        if verbose:
+            reporter.success(msg)
+
+    return info, error, success
+
+
+def _register_signal_handlers(reporter: ProgressReporter):
+    import signal
+
+    def handle_signal(signum, _):
+        # Handle the signal here
+        reporter.info(f"Received signal {signum}, exiting...")
+        reporter.dispose()
+        for task in asyncio.all_tasks():
+            task.cancel()
+        reporter.info("All tasks cancelled. Exiting...")
+
+    # Register signal handlers for SIGINT and SIGHUP
+    signal.signal(signal.SIGINT, handle_signal)
+
+    if sys.platform != "win32":
+        signal.signal(signal.SIGHUP, handle_signal)
+
+
 def index_cli(
     root: str,
     init: bool,
@@ -81,98 +114,81 @@ def index_cli(
     emit: str | None,
     dryrun: bool,
     overlay_defaults: bool,
+    skip_validations: bool,
     cli: bool = False,
 ):
     """Run the pipeline with the given config."""
+    progress_reporter = load_progress_reporter(reporter or "rich")
+    info, error, success = _logger(progress_reporter)
     run_id = resume or time.strftime("%Y%m%d-%H%M%S")
-    _enable_logging(root, run_id, verbose)
-    progress_reporter = _get_progress_reporter(reporter)
 
     if init:
         _initialize_project_at(root, progress_reporter)
         sys.exit(0)
-    if overlay_defaults:
-        pipeline_config: str | PipelineConfig = _create_default_config(
-            root, config, verbose, dryrun or False, progress_reporter
-        )
+
+    if overlay_defaults or config:
+        config_path = (
+            Path(root) / config if config else resolve_config_path_with_root(root)
+        )
+        default_config = load_config_from_file(config_path)
     else:
-        pipeline_config: str | PipelineConfig = config or _create_default_config(
-            root, None, verbose, dryrun or False, progress_reporter
-        )
+        try:
+            config_path = resolve_config_path_with_root(root)
+            default_config = load_config_from_file(config_path)
+        except FileNotFoundError:
+            default_config = create_graphrag_config(root_dir=root)
+
+    if nocache:
+        default_config.cache.type = CacheType.none
+
+    enabled_logging, log_path = enable_logging_with_config(
+        default_config, run_id, verbose
+    )
-    cache = NoopPipelineCache() if nocache else None
+    if enabled_logging:
+        info(f"Logging enabled at {log_path}", True)
+    else:
+        info(
+            f"Logging not enabled for config {_redact(default_config.model_dump())}",
+            True,
+        )
+
+    if not skip_validations:
+        validate_config_names(progress_reporter, default_config)
+
+    info(f"Starting pipeline run for: {run_id}, {dryrun=}", verbose)
+    info(
+        f"Using default configuration: {_redact(default_config.model_dump())}",
+        verbose,
+    )
+
+    if dryrun:
+        info("Dry run complete, exiting...", True)
+        sys.exit(0)
+
     pipeline_emit = emit.split(",") if emit else None
     encountered_errors = False
 
-    # Run pre-flight validation on config model values
-    parameters = _read_config_parameters(root, config, progress_reporter)
-    validate_config_names(progress_reporter, parameters)
+    _register_signal_handlers(progress_reporter)
 
-    def _run_workflow_async() -> None:
-        import signal
+    outputs = asyncio.run(
+        build_index(
+            default_config,
+            run_id,
+            memprofile,
+            progress_reporter,
+            pipeline_emit,
+        )
+    )
+    encountered_errors = any(
+        output.errors and len(output.errors) > 0 for output in outputs
+    )
 
-        def handle_signal(signum, _):
-            # Handle the signal here
-            progress_reporter.info(f"Received signal {signum}, exiting...")
-            progress_reporter.dispose()
-            for task in asyncio.all_tasks():
-                task.cancel()
-            progress_reporter.info("All tasks cancelled. Exiting...")
-
-        # Register signal handlers for SIGINT and SIGHUP
-        signal.signal(signal.SIGINT, handle_signal)
-
-        if sys.platform != "win32":
-            signal.signal(signal.SIGHUP, handle_signal)
-
-        async def execute():
-            nonlocal encountered_errors
-            async for output in run_pipeline_with_config(
-                pipeline_config,
-                run_id=run_id,
-                memory_profile=memprofile,
-                cache=cache,
-                progress_reporter=progress_reporter,
-                emit=(
-                    [TableEmitterType(e) for e in pipeline_emit]
-                    if pipeline_emit
-                    else None
-                ),
-                is_resume_run=bool(resume),
-            ):
-                if output.errors and len(output.errors) > 0:
-                    encountered_errors = True
-                    progress_reporter.error(output.workflow)
-                else:
-                    progress_reporter.success(output.workflow)
-
-                progress_reporter.info(str(output.result))
-
-        if platform.system() == "Windows":
-            import nest_asyncio  # type: ignore Ignoring because out of windows this will cause an error
-
-            nest_asyncio.apply()
-            loop = asyncio.get_event_loop()
-            loop.run_until_complete(execute())
-        elif sys.version_info >= (3, 11):
-            import uvloop  # type: ignore Ignoring because on windows this will cause an error
-
-            with asyncio.Runner(loop_factory=uvloop.new_event_loop) as runner:  # type: ignore Ignoring because minor versions this will throw an error
-                runner.run(execute())
-        else:
-            import uvloop  # type: ignore Ignoring because on windows this will cause an error
-
-            uvloop.install()
-            asyncio.run(execute())
-
-    _run_workflow_async()
     progress_reporter.stop()
     if encountered_errors:
-        progress_reporter.error(
-            "Errors occurred during the pipeline run, see logs for more details."
+        error(
+            "Errors occurred during the pipeline run, see logs for more details.", True
         )
     else:
-        progress_reporter.success("All workflows completed successfully.")
+        success("All workflows completed successfully.", True)
 
     if cli:
         sys.exit(1 if encountered_errors else 0)
@@ -225,101 +241,3 @@ def _initialize_project_at(path: str, reporter: ProgressReporter) -> None:
         file.write(
             COMMUNITY_REPORT_PROMPT.encode(encoding="utf-8", errors="strict")
         )
-
-
-def _create_default_config(
-    root: str,
-    config: str | None,
-    verbose: bool,
-    dryrun: bool,
-    reporter: ProgressReporter,
-) -> PipelineConfig:
-    """Overlay default values on an existing config or create a default config if none is provided."""
-    if config and not Path(config).exists():
-        msg = f"Configuration file {config} does not exist"
-        raise ValueError
-
-    if not Path(root).exists():
-        msg = f"Root directory {root} does not exist"
-        raise ValueError(msg)
-
-    parameters = _read_config_parameters(root, config, reporter)
-    log.info(
-        "using default configuration: %s",
-        redact(parameters.model_dump()),
-    )
-
-    if verbose or dryrun:
-        reporter.info(f"Using default configuration: {redact(parameters.model_dump())}")
-    result = create_pipeline_config(parameters, verbose)
-    if verbose or dryrun:
-        reporter.info(f"Final Config: {redact(result.model_dump())}")
-
-    if dryrun:
-        reporter.info("dry run complete, exiting...")
-        sys.exit(0)
-    return result
-
-
-def _read_config_parameters(root: str, config: str | None, reporter: ProgressReporter):
-    _root = Path(root)
-    settings_yaml = (
-        Path(config)
-        if config and Path(config).suffix in [".yaml", ".yml"]
-        else _root / "settings.yaml"
-    )
-    if not settings_yaml.exists():
-        settings_yaml = _root / "settings.yml"
-    settings_json = (
-        Path(config)
-        if config and Path(config).suffix == ".json"
-        else _root / "settings.json"
-    )
-
-    if settings_yaml.exists():
-        reporter.success(f"Reading settings from {settings_yaml}")
-        with settings_yaml.open("rb") as file:
-            import yaml
-
-            data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict"))
-            return create_graphrag_config(data, root)
-
-    if settings_json.exists():
-        reporter.success(f"Reading settings from {settings_json}")
-        with settings_json.open("rb") as file:
-            import json
-
-            data = json.loads(file.read().decode(encoding="utf-8", errors="strict"))
-            return create_graphrag_config(data, root)
-
-    reporter.success("Reading settings from environment variables")
-    return create_graphrag_config(root_dir=root)
-
-
-def _get_progress_reporter(reporter_type: str | None) -> ProgressReporter:
-    if reporter_type is None or reporter_type == "rich":
-        return RichProgressReporter("GraphRAG Indexer ")
-    if reporter_type == "print":
-        return PrintProgressReporter("GraphRAG Indexer ")
-    if reporter_type == "none":
-        return NullProgressReporter()
-
-    msg = f"Invalid progress reporter type: {reporter_type}"
-    raise ValueError(msg)
-
-
-def _enable_logging(root_dir: str, run_id: str, verbose: bool) -> None:
-    logging_file = (
-        Path(root_dir) / "output" / run_id / "reports" / "indexing-engine.log"
-    )
-    logging_file.parent.mkdir(parents=True, exist_ok=True)
-    logging_file.touch(exist_ok=True)
-
-    logging.basicConfig(
-        filename=str(logging_file),
-        filemode="a",
-        format="%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s",
-        datefmt="%H:%M:%S",
-        level=logging.DEBUG if verbose else logging.INFO,
-    )
graphrag/index/progress/load_progress_reporter.py (new file, 30 lines)
@@ -0,0 +1,30 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""Load a progress reporter."""

from .rich import RichProgressReporter
from .types import NullProgressReporter, PrintProgressReporter, ProgressReporter


def load_progress_reporter(reporter_type: str = "none") -> ProgressReporter:
    """Load a progress reporter.

    Parameters
    ----------
    reporter_type : {"rich", "print", "none"}, default="none"
        The type of progress reporter to load.

    Returns
    -------
    ProgressReporter
    """
    if reporter_type == "rich":
        return RichProgressReporter("GraphRAG Indexer ")
    if reporter_type == "print":
        return PrintProgressReporter("GraphRAG Indexer ")
    if reporter_type == "none":
        return NullProgressReporter()

    msg = f"Invalid progress reporter type: {reporter_type}"
    raise ValueError(msg)
@@ -1,4 +0,0 @@
-#!/bin/bash
-
-# Use CLI Form
-poetry run python -m graphrag.index --config ./examples/single_verb/pipeline.yml