Merge from main

This commit is contained in:
Alonso Guevara 2024-09-03 16:34:52 -06:00
commit 41ea554fda
17 changed files with 176 additions and 285 deletions

View File

@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Consistent config loading. Resolves #99 and Resolves #1049"
}

View File

@ -3,6 +3,7 @@
"""The Indexing Engine default config package root."""
from .config_file_loader import load_config_from_file, search_for_config_in_root_dir
from .create_graphrag_config import (
create_graphrag_config,
)
@ -42,6 +43,8 @@ from .input_models import (
TextEmbeddingConfigInput,
UmapConfigInput,
)
from .load_config import load_config
from .logging import enable_logging_with_config
from .models import (
CacheConfig,
ChunkingConfig,
@ -65,6 +68,7 @@ from .models import (
UmapConfig,
)
from .read_dotenv import read_dotenv
from .resolve_timestamp_path import resolve_timestamp_path
__all__ = [
"ApiKeyMissingError",
@ -119,5 +123,10 @@ __all__ = [
"UmapConfig",
"UmapConfigInput",
"create_graphrag_config",
"enable_logging_with_config",
"load_config",
"load_config_from_file",
"read_dotenv",
"resolve_timestamp_path",
"search_for_config_in_root_dir",
]

View File

@ -9,13 +9,13 @@ from pathlib import Path
import yaml
from . import create_graphrag_config
from .create_graphrag_config import create_graphrag_config
from .models.graph_rag_config import GraphRagConfig
_default_config_files = ["settings.yaml", "settings.yml", "settings.json"]
def resolve_config_path_with_root(root: str | Path) -> Path:
def search_for_config_in_root_dir(root: str | Path) -> Path | None:
"""Resolve the config path from the given root directory.
Parameters
@ -26,13 +26,9 @@ def resolve_config_path_with_root(root: str | Path) -> Path:
Returns
-------
Path
The resolved config file path.
Raises
------
FileNotFoundError
If the config file is not found or cannot be resolved for the directory.
Path | None
returns a Path if there is a config in the root directory
Otherwise returns None.
"""
root = Path(root)
@ -44,8 +40,7 @@ def resolve_config_path_with_root(root: str | Path) -> Path:
if (root / file).is_file():
return root / file
msg = f"Unable to resolve config file for parent directory: {root}"
raise FileNotFoundError(msg)
return None
class ConfigFileLoader(ABC):

View File

@ -0,0 +1,65 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Default method for loading config."""
from pathlib import Path
from .config_file_loader import load_config_from_file, search_for_config_in_root_dir
from .create_graphrag_config import create_graphrag_config
from .models.graph_rag_config import GraphRagConfig
from .resolve_timestamp_path import resolve_timestamp_path
def load_config(
root_dir: str | Path,
config_filepath: str | Path | None = None,
run_id: str | None = None,
) -> GraphRagConfig:
"""Load configuration from a file or create a default configuration.
If a config file is not found the default configuration is created.
Parameters
----------
root_dir : str | Path
The root directory of the project. Will search for the config file in this directory.
config_filepath : str | Path | None
The path to the config file.
If None, searches for config file in root and
if not found creates a default configuration.
run_id : str | None
The run id to use for resolving timestamp_paths.
"""
root = Path(root_dir).resolve()
# If user specified a config file path then it is required
if config_filepath:
config_path = (root / config_filepath).resolve()
if not config_path.exists():
msg = f"Specified Config file not found: {config_path}"
raise FileNotFoundError(msg)
# Else optional resolve the config path from the root directory if it exists
config_path = search_for_config_in_root_dir(root)
if config_path:
config = load_config_from_file(config_path)
else:
config = create_graphrag_config(root_dir=str(root))
if run_id:
config.storage.base_dir = str(
resolve_timestamp_path((root / config.storage.base_dir).resolve(), run_id)
)
config.reporting.base_dir = str(
resolve_timestamp_path((root / config.reporting.base_dir).resolve(), run_id)
)
else:
config.storage.base_dir = str(
resolve_timestamp_path((root / config.storage.base_dir).resolve())
)
config.reporting.base_dir = str(
resolve_timestamp_path((root / config.reporting.base_dir).resolve())
)
return config

View File

@ -8,7 +8,6 @@ from pathlib import Path
from .enums import ReportingType
from .models.graph_rag_config import GraphRagConfig
from .resolve_timestamp_path import resolve_timestamp_path
def enable_logging(log_filepath: str | Path, verbose: bool = False) -> None:
@ -35,7 +34,7 @@ def enable_logging(log_filepath: str | Path, verbose: bool = False) -> None:
def enable_logging_with_config(
config: GraphRagConfig, timestamp_value: str, verbose: bool = False
config: GraphRagConfig, verbose: bool = False
) -> tuple[bool, str]:
"""Enable logging to a file based on the config.
@ -56,10 +55,7 @@ def enable_logging_with_config(
(True, str) if logging was enabled.
"""
if config.reporting.type == ReportingType.file:
log_path = resolve_timestamp_path(
Path(config.root_dir) / config.reporting.base_dir / "indexing-engine.log",
timestamp_value,
)
log_path = Path(config.reporting.base_dir) / "indexing-engine.log"
enable_logging(log_path, verbose)
return (True, str(log_path))
return (False, "")

View File

@ -63,11 +63,6 @@ if __name__ == "__main__":
help="Create an initial configuration in the given path.",
action="store_true",
)
parser.add_argument(
"--overlay-defaults",
help="Overlay default configuration values on a provided configuration file (--config).",
action="store_true",
)
parser.add_argument(
"--skip-validations",
help="Skip any preflight validation. Useful when running no LLM steps.",
@ -83,21 +78,17 @@ if __name__ == "__main__":
)
args = parser.parse_args()
if args.overlay_defaults and not args.config:
parser.error("--overlay-defaults requires --config")
index_cli(
root=args.root,
root_dir=args.root,
verbose=args.verbose or False,
resume=args.resume,
update_index_id=args.update_index,
memprofile=args.memprofile or False,
nocache=args.nocache or False,
reporter=args.reporter,
config=args.config,
config_filepath=args.config,
emit=args.emit,
dryrun=args.dryrun or False,
init=args.init or False,
overlay_defaults=args.overlay_defaults or False,
skip_validations=args.skip_validations or False,
)

View File

@ -8,9 +8,9 @@ WARNING: This API is under development and may undergo changes in future release
Backwards compatibility is not guaranteed at this time.
"""
from graphrag.config.enums import CacheType
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.config.resolve_timestamp_path import resolve_timestamp_path
from pathlib import Path
from graphrag.config import CacheType, GraphRagConfig
from .cache.noop_pipeline_cache import NoopPipelineCache
from .create_pipeline_config import create_pipeline_config
@ -50,11 +50,7 @@ async def build_index(
list[PipelineRunResult]
The list of pipeline run results
"""
try:
resolve_timestamp_path(config.storage.base_dir, run_id)
resume = True
except ValueError as _:
resume = False
resume = Path(config.storage.base_dir).exists()
pipeline_config = create_pipeline_config(config)
pipeline_cache = (
NoopPipelineCache() if config.cache.type == CacheType.none is None else None

View File

@ -11,13 +11,7 @@ import time
import warnings
from pathlib import Path
from graphrag.config import create_graphrag_config
from graphrag.config.config_file_loader import (
load_config_from_file,
resolve_config_path_with_root,
)
from graphrag.config.enums import CacheType
from graphrag.config.logging import enable_logging_with_config
from graphrag.config import CacheType, enable_logging_with_config, load_config
from .api import build_index, update_index
from .graph.extractors.claims.prompts import CLAIM_EXTRACTION_PROMPT
@ -103,7 +97,7 @@ def _register_signal_handlers(reporter: ProgressReporter):
def index_cli(
root: str,
root_dir: str,
init: bool,
verbose: bool,
resume: str | None,
@ -111,10 +105,9 @@ def index_cli(
memprofile: bool,
nocache: bool,
reporter: str | None,
config: str | None,
config_filepath: str | None,
emit: str | None,
dryrun: bool,
overlay_defaults: bool,
skip_validations: bool,
):
"""Run the pipeline with the given config."""
@ -123,41 +116,30 @@ def index_cli(
run_id = resume or time.strftime("%Y%m%d-%H%M%S")
if init:
_initialize_project_at(root, progress_reporter)
_initialize_project_at(root_dir, progress_reporter)
sys.exit(0)
if overlay_defaults or config:
config_path = (
Path(root) / config if config else resolve_config_path_with_root(root)
)
default_config = load_config_from_file(config_path)
else:
try:
config_path = resolve_config_path_with_root(root)
default_config = load_config_from_file(config_path)
except FileNotFoundError:
default_config = create_graphrag_config(root_dir=root)
root = Path(root_dir).resolve()
config = load_config(root, config_filepath, run_id)
if nocache:
default_config.cache.type = CacheType.none
config.cache.type = CacheType.none
enabled_logging, log_path = enable_logging_with_config(
default_config, run_id, verbose
)
enabled_logging, log_path = enable_logging_with_config(config, verbose)
if enabled_logging:
info(f"Logging enabled at {log_path}", True)
else:
info(
f"Logging not enabled for config {_redact(default_config.model_dump())}",
f"Logging not enabled for config {_redact(config.model_dump())}",
True,
)
if skip_validations:
validate_config_names(progress_reporter, default_config)
validate_config_names(progress_reporter, config)
info(f"Starting pipeline run for: {run_id}, {dryrun=}", verbose)
info(
f"Using default configuration: {_redact(default_config.model_dump())}",
f"Using default configuration: {_redact(config.model_dump())}",
verbose,
)
@ -169,26 +151,15 @@ def index_cli(
_register_signal_handlers(progress_reporter)
if update_index_id:
outputs = asyncio.run(
update_index(
default_config,
memprofile,
update_index_id,
progress_reporter,
pipeline_emit,
)
)
else:
outputs = asyncio.run(
build_index(
default_config,
run_id,
memprofile,
progress_reporter,
pipeline_emit,
)
outputs = asyncio.run(
build_index(
config,
run_id,
memprofile,
progress_reporter,
pipeline_emit,
)
)
encountered_errors = any(
output.errors and len(output.errors) > 0 for output in outputs
)

View File

@ -5,11 +5,11 @@
from pathlib import Path
from graphrag.config import load_config
from graphrag.index.progress import PrintProgressReporter
from graphrag.prompt_tune.generator import MAX_TOKEN_COUNT
from graphrag.prompt_tune.loader import (
MIN_CHUNK_SIZE,
read_config_parameters,
)
from . import api
@ -53,11 +53,12 @@ async def prompt_tune(
- min_examples_required: The minimum number of examples required for entity extraction prompts.
"""
reporter = PrintProgressReporter("")
graph_config = read_config_parameters(root, reporter, config)
root_path = Path(root).resolve()
graph_config = load_config(root_path, config)
prompts = await api.generate_indexing_prompts(
config=graph_config,
root=root,
root=str(root_path),
chunk_size=chunk_size,
limit=limit,
selection_method=selection_method,
@ -70,7 +71,7 @@ async def prompt_tune(
k=k,
)
output_path = Path(output)
output_path = (root_path / output).resolve()
if output_path:
reporter.info(f"Writing prompts to {output_path}")
output_path.mkdir(parents=True, exist_ok=True)

View File

@ -3,12 +3,10 @@
"""Fine-tuning config and data loader module."""
from .config import read_config_parameters
from .input import MIN_CHUNK_OVERLAP, MIN_CHUNK_SIZE, load_docs_in_chunks
__all__ = [
"MIN_CHUNK_OVERLAP",
"MIN_CHUNK_SIZE",
"load_docs_in_chunks",
"read_config_parameters",
]

View File

@ -1,61 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Config loading, parsing and handling module."""
from pathlib import Path
from graphrag.config import create_graphrag_config
from graphrag.index.progress.types import ProgressReporter
def read_config_parameters(
root: str, reporter: ProgressReporter, config: str | None = None
):
"""Read the configuration parameters from the settings file or environment variables.
Parameters
----------
- root: The root directory where the parameters are.
- reporter: The progress reporter.
- config: The path to the settings file.
"""
_root = Path(root)
settings_yaml = (
Path(config)
if config and Path(config).suffix in [".yaml", ".yml"]
else _root / "settings.yaml"
)
if not settings_yaml.exists():
settings_yaml = _root / "settings.yml"
if settings_yaml.exists():
reporter.info(f"Reading settings from {settings_yaml}")
with settings_yaml.open("rb") as file:
import yaml
data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict"))
return create_graphrag_config(data, root)
settings_json = (
Path(config)
if config and Path(config).suffix == ".json"
else _root / "settings.json"
)
if settings_yaml.exists():
reporter.info(f"Reading settings from {settings_yaml}")
with settings_yaml.open("rb") as file:
import yaml
data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict"))
return create_graphrag_config(data, root)
if settings_json.exists():
reporter.info(f"Reading settings from {settings_json}")
with settings_json.open("rb") as file:
import json
data = json.loads(file.read().decode(encoding="utf-8", errors="strict"))
return create_graphrag_config(data, root)
reporter.info("Reading settings from environment variables")
return create_graphrag_config(root_dir=root)

View File

@ -24,8 +24,7 @@ from typing import Any
import pandas as pd
from pydantic import validate_call
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.config.resolve_timestamp_path import resolve_timestamp_path
from graphrag.config import GraphRagConfig
from graphrag.index.progress.types import PrintProgressReporter
from graphrag.model.entity import Entity
from graphrag.query.structured_search.base import SearchResult # noqa: TCH001
@ -149,7 +148,6 @@ async def global_search_streaming(
@validate_call(config={"arbitrary_types_allowed": True})
async def local_search(
root_dir: str | None,
config: GraphRagConfig,
nodes: pd.DataFrame,
entities: pd.DataFrame,
@ -196,9 +194,8 @@ async def local_search(
_entities = read_indexer_entities(nodes, entities, community_level)
base_dir = Path(str(root_dir)) / config.storage.base_dir
resolved_base_dir = resolve_timestamp_path(base_dir)
lancedb_dir = resolved_base_dir / "lancedb"
lancedb_dir = Path(config.storage.base_dir) / "lancedb"
vector_store_args.update({"db_uri": str(lancedb_dir)})
description_embedding_store = _get_embedding_description_store(
entities=_entities,
@ -227,7 +224,6 @@ async def local_search(
@validate_call(config={"arbitrary_types_allowed": True})
async def local_search_streaming(
root_dir: str | None,
config: GraphRagConfig,
nodes: pd.DataFrame,
entities: pd.DataFrame,
@ -271,9 +267,8 @@ async def local_search_streaming(
_entities = read_indexer_entities(nodes, entities, community_level)
base_dir = Path(str(root_dir)) / config.storage.base_dir
resolved_base_dir = resolve_timestamp_path(base_dir)
lancedb_dir = resolved_base_dir / "lancedb"
lancedb_dir = lancedb_dir = Path(config.storage.base_dir) / "lancedb"
vector_store_args.update({"db_uri": str(lancedb_dir)})
description_embedding_store = _get_embedding_description_store(
entities=_entities,

View File

@ -4,17 +4,12 @@
"""Command line interface for the query module."""
import asyncio
import re
import sys
from pathlib import Path
import pandas as pd
from graphrag.config import (
GraphRagConfig,
create_graphrag_config,
)
from graphrag.config.resolve_timestamp_path import resolve_timestamp_path
from graphrag.config import load_config, resolve_timestamp_path
from graphrag.index.progress import PrintProgressReporter
from . import api
@ -25,7 +20,7 @@ reporter = PrintProgressReporter("")
def run_global_search(
config_filepath: str | None,
data_dir: str | None,
root_dir: str | None,
root_dir: str,
community_level: int,
response_type: str,
streaming: bool,
@ -35,10 +30,15 @@ def run_global_search(
Loads index files required for global search and calls the Query API.
"""
data_dir, root_dir, config = _configure_paths_and_settings(
data_dir, root_dir, config_filepath
)
data_path = Path(data_dir)
root = Path(root_dir).resolve()
config = load_config(root, config_filepath)
if data_dir:
config.storage.base_dir = str(
resolve_timestamp_path((root / data_dir).resolve())
)
data_path = Path(config.storage.base_dir).resolve()
final_nodes: pd.DataFrame = pd.read_parquet(
data_path / "create_final_nodes.parquet"
@ -98,7 +98,7 @@ def run_global_search(
def run_local_search(
config_filepath: str | None,
data_dir: str | None,
root_dir: str | None,
root_dir: str,
community_level: int,
response_type: str,
streaming: bool,
@ -108,10 +108,15 @@ def run_local_search(
Loads index files required for local search and calls the Query API.
"""
data_dir, root_dir, config = _configure_paths_and_settings(
data_dir, root_dir, config_filepath
)
data_path = Path(data_dir)
root = Path(root_dir).resolve()
config = load_config(root, config_filepath)
if data_dir:
config.storage.base_dir = str(
resolve_timestamp_path((root / data_dir).resolve())
)
data_path = Path(config.storage.base_dir).resolve()
final_nodes = pd.read_parquet(data_path / "create_final_nodes.parquet")
final_community_reports = pd.read_parquet(
@ -137,7 +142,6 @@ def run_local_search(
context_data = None
get_context_data = True
async for stream_chunk in api.local_search_streaming(
root_dir=root_dir,
config=config,
nodes=final_nodes,
entities=final_entities,
@ -163,7 +167,6 @@ def run_local_search(
# not streaming
response, context_data = asyncio.run(
api.local_search(
root_dir=root_dir,
config=config,
nodes=final_nodes,
entities=final_entities,
@ -180,77 +183,3 @@ def run_local_search(
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
# External users should use the API directly to get the response and context data.
return response, context_data
def _configure_paths_and_settings(
data_dir: str | None,
root_dir: str | None,
config_filepath: str | None,
) -> tuple[str, str | None, GraphRagConfig]:
config = _create_graphrag_config(root_dir, config_filepath)
if data_dir is None and root_dir is None:
msg = "Either data_dir or root_dir must be provided."
raise ValueError(msg)
if data_dir is None:
base_dir = Path(str(root_dir)) / config.storage.base_dir
data_dir = str(resolve_timestamp_path(base_dir))
return data_dir, root_dir, config
def _infer_data_dir(root: str) -> str:
output = Path(root) / "output"
# use the latest data-run folder
if output.exists():
expr = re.compile(r"\d{8}-\d{6}")
filtered = [f for f in output.iterdir() if f.is_dir() and expr.match(f.name)]
folders = sorted(filtered, key=lambda f: f.name, reverse=True)
if len(folders) > 0:
folder = folders[0]
return str((folder / "artifacts").absolute())
msg = f"Could not infer data directory from root={root}"
raise ValueError(msg)
def _create_graphrag_config(
root: str | None,
config_filepath: str | None,
) -> GraphRagConfig:
"""Create a GraphRag configuration."""
return _read_config_parameters(root or "./", config_filepath)
def _read_config_parameters(root: str, config: str | None):
_root = Path(root)
settings_yaml = (
Path(config)
if config and Path(config).suffix in [".yaml", ".yml"]
else _root / "settings.yaml"
)
if not settings_yaml.exists():
settings_yaml = _root / "settings.yml"
if settings_yaml.exists():
reporter.info(f"Reading settings from {settings_yaml}")
with settings_yaml.open(
"rb",
) as file:
import yaml
data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict"))
return create_graphrag_config(data, root)
settings_json = (
Path(config)
if config and Path(config).suffix == ".json"
else _root / "settings.json"
)
if settings_json.exists():
reporter.info(f"Reading settings from {settings_json}")
with settings_json.open("rb") as file:
import json
data = json.loads(file.read().decode(encoding="utf-8", errors="strict"))
return create_graphrag_config(data, root)
reporter.info("Reading settings from environment variables")
return create_graphrag_config(root_dir=root)

View File

@ -123,6 +123,7 @@ _convert_global_search_nb = 'jupyter nbconvert --output-dir=docsite/posts/query/
_semversioner_release = "semversioner release"
_semversioner_changelog = "semversioner changelog > CHANGELOG.md"
_semversioner_update_toml_version = "update-toml update --path tool.poetry.version --value $(poetry run semversioner current-version)"
semversioner_add = "semversioner add-change"
coverage_report = 'coverage report --omit "**/tests/**" --show-missing'
check_format = 'ruff format . --check --preview'
fix = "ruff --preview check --fix ."

View File

@ -0,0 +1,33 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from pathlib import Path
from graphrag.config.resolve_timestamp_path import resolve_timestamp_path
def test_resolve_timestamp_path_no_timestamp_with_run_id():
path = Path("path/to/data")
result = resolve_timestamp_path(path, "20240812-121000")
assert result == path
def test_resolve_timestamp_path_no_timestamp_without_run_id():
path = Path("path/to/data")
result = resolve_timestamp_path(path)
assert result == path
def test_resolve_timestamp_path_with_timestamp_and_run_id():
path = Path("some/path/${timestamp}/data")
expected = Path("some/path/20240812/data")
result = resolve_timestamp_path(path, "20240812")
assert result == expected
def test_resolve_timestamp_path_with_timestamp_and_inferred_directory():
cwd = Path(__file__).parent
path = cwd / "fixtures/timestamp_dirs/${timestamp}/data"
expected = cwd / "fixtures/timestamp_dirs/20240812-120000/data"
result = resolve_timestamp_path(path)
assert result == expected

View File

@ -1,32 +0,0 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from pathlib import Path
import pytest
from graphrag.query.cli import _infer_data_dir
def test_infer_data_dir():
root = "./tests/unit/query/data/defaults"
result = Path(_infer_data_dir(root))
assert result.parts[-2] == "20240812-121000"
def test_infer_data_dir_ignores_hidden_files():
"""A hidden file, starting with '.', will naturally be selected as latest data directory."""
root = "./tests/unit/query/data/hidden"
result = Path(_infer_data_dir(root))
assert result.parts[-2] == "20240812-121000"
def test_infer_data_dir_ignores_non_numeric():
root = "./tests/unit/query/data/non-numeric"
result = Path(_infer_data_dir(root))
assert result.parts[-2] == "20240812-121000"
def test_infer_data_dir_throws_on_no_match():
root = "./tests/unit/query/data/empty"
with pytest.raises(ValueError): # noqa PT011 (this is what is actually thrown...)
_infer_data_dir(root)