Merge from main
This commit is contained in:
commit 41ea554fda
semversioner change file (new)
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Consistent config loading. Resolves #99 and Resolves #1049"
+}

graphrag/config/__init__.py
@@ -3,6 +3,7 @@
 
 """The Indexing Engine default config package root."""
 
+from .config_file_loader import load_config_from_file, search_for_config_in_root_dir
 from .create_graphrag_config import (
     create_graphrag_config,
 )
@@ -42,6 +43,8 @@ from .input_models import (
     TextEmbeddingConfigInput,
     UmapConfigInput,
 )
+from .load_config import load_config
+from .logging import enable_logging_with_config
 from .models import (
     CacheConfig,
     ChunkingConfig,
@@ -65,6 +68,7 @@ from .models import (
     UmapConfig,
 )
 from .read_dotenv import read_dotenv
+from .resolve_timestamp_path import resolve_timestamp_path
 
 __all__ = [
     "ApiKeyMissingError",
@@ -119,5 +123,10 @@ __all__ = [
     "UmapConfig",
     "UmapConfigInput",
     "create_graphrag_config",
+    "enable_logging_with_config",
+    "load_config",
+    "load_config_from_file",
     "read_dotenv",
+    "resolve_timestamp_path",
+    "search_for_config_in_root_dir",
 ]
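
With these exports in place, callers can take the whole configuration surface from the package root instead of reaching into submodules. A sketch of the migration this enables (both forms appear verbatim in the index CLI diff further down):

    # before: deep module imports
    from graphrag.config.enums import CacheType
    from graphrag.config.logging import enable_logging_with_config

    # after: a single package-root import
    from graphrag.config import CacheType, enable_logging_with_config, load_config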

graphrag/config/config_file_loader.py
@@ -9,13 +9,13 @@ from pathlib import Path
 
 import yaml
 
-from . import create_graphrag_config
+from .create_graphrag_config import create_graphrag_config
 from .models.graph_rag_config import GraphRagConfig
 
 _default_config_files = ["settings.yaml", "settings.yml", "settings.json"]
 
 
-def resolve_config_path_with_root(root: str | Path) -> Path:
+def search_for_config_in_root_dir(root: str | Path) -> Path | None:
     """Resolve the config path from the given root directory.
 
     Parameters
@@ -26,13 +26,9 @@ def resolve_config_path_with_root(root: str | Path) -> Path:
 
     Returns
     -------
-    Path
-        The resolved config file path.
-
-    Raises
-    ------
-    FileNotFoundError
-        If the config file is not found or cannot be resolved for the directory.
+    Path | None
+        returns a Path if there is a config in the root directory
+        Otherwise returns None.
     """
     root = Path(root)
 
@@ -44,8 +40,7 @@ def resolve_config_path_with_root(root: str | Path) -> Path:
         if (root / file).is_file():
             return root / file
 
-    msg = f"Unable to resolve config file for parent directory: {root}"
-    raise FileNotFoundError(msg)
+    return None
 
 
 class ConfigFileLoader(ABC):
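
Since the renamed search_for_config_in_root_dir returns None instead of raising FileNotFoundError, each call site now picks its own fallback. A minimal sketch of the resulting pattern, built only from the package-root exports above (the ./my_project path is hypothetical):

    from pathlib import Path

    from graphrag.config import (
        create_graphrag_config,
        load_config_from_file,
        search_for_config_in_root_dir,
    )

    root = Path("./my_project")  # hypothetical project root
    config_path = search_for_config_in_root_dir(root)  # Path | None
    config = (
        load_config_from_file(config_path)
        if config_path
        else create_graphrag_config(root_dir=str(root))  # fall back to defaults
    )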

graphrag/config/load_config.py (new file, 65 lines)
@@ -0,0 +1,65 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Default method for loading config."""
+
+from pathlib import Path
+
+from .config_file_loader import load_config_from_file, search_for_config_in_root_dir
+from .create_graphrag_config import create_graphrag_config
+from .models.graph_rag_config import GraphRagConfig
+from .resolve_timestamp_path import resolve_timestamp_path
+
+
+def load_config(
+    root_dir: str | Path,
+    config_filepath: str | Path | None = None,
+    run_id: str | None = None,
+) -> GraphRagConfig:
+    """Load configuration from a file or create a default configuration.
+
+    If a config file is not found the default configuration is created.
+
+    Parameters
+    ----------
+    root_dir : str | Path
+        The root directory of the project. Will search for the config file in this directory.
+    config_filepath : str | Path | None
+        The path to the config file.
+        If None, searches for config file in root and
+        if not found creates a default configuration.
+    run_id : str | None
+        The run id to use for resolving timestamp_paths.
+    """
+    root = Path(root_dir).resolve()
+
+    # If user specified a config file path then it is required
+    if config_filepath:
+        config_path = (root / config_filepath).resolve()
+        if not config_path.exists():
+            msg = f"Specified Config file not found: {config_path}"
+            raise FileNotFoundError(msg)
+
+    # Else optional resolve the config path from the root directory if it exists
+    config_path = search_for_config_in_root_dir(root)
+    if config_path:
+        config = load_config_from_file(config_path)
+    else:
+        config = create_graphrag_config(root_dir=str(root))
+
+    if run_id:
+        config.storage.base_dir = str(
+            resolve_timestamp_path((root / config.storage.base_dir).resolve(), run_id)
+        )
+        config.reporting.base_dir = str(
+            resolve_timestamp_path((root / config.reporting.base_dir).resolve(), run_id)
+        )
+    else:
+        config.storage.base_dir = str(
+            resolve_timestamp_path((root / config.storage.base_dir).resolve())
+        )
+        config.reporting.base_dir = str(
+            resolve_timestamp_path((root / config.reporting.base_dir).resolve())
+        )
+
+    return config
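
A short usage sketch of the new entry point; ./ragtest is a hypothetical project root, and the run id follows the %Y%m%d-%H%M%S format the index CLI generates:

    from graphrag.config import load_config

    # Search ./ragtest for settings.yaml/.yml/.json; fall back to defaults.
    config = load_config("./ragtest")

    # Pin an explicit config file and resolve ${timestamp} paths for one run.
    config = load_config("./ragtest", "settings.yaml", run_id="20240812-120000")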

graphrag/config/logging.py
@@ -8,7 +8,6 @@ from pathlib import Path
 
 from .enums import ReportingType
 from .models.graph_rag_config import GraphRagConfig
-from .resolve_timestamp_path import resolve_timestamp_path
 
 
 def enable_logging(log_filepath: str | Path, verbose: bool = False) -> None:
@@ -35,7 +34,7 @@ def enable_logging(log_filepath: str | Path, verbose: bool = False) -> None:
 
 
 def enable_logging_with_config(
-    config: GraphRagConfig, timestamp_value: str, verbose: bool = False
+    config: GraphRagConfig, verbose: bool = False
 ) -> tuple[bool, str]:
     """Enable logging to a file based on the config.
 
@@ -56,10 +55,7 @@ def enable_logging_with_config(
         (True, str) if logging was enabled.
     """
     if config.reporting.type == ReportingType.file:
-        log_path = resolve_timestamp_path(
-            Path(config.root_dir) / config.reporting.base_dir / "indexing-engine.log",
-            timestamp_value,
-        )
+        log_path = Path(config.reporting.base_dir) / "indexing-engine.log"
         enable_logging(log_path, verbose)
         return (True, str(log_path))
     return (False, "")
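
The timestamp_value parameter disappears because load_config now resolves reporting.base_dir up front, so the logger only needs the config. A sketch of the new call shape (hypothetical root):

    config = load_config("./ragtest", run_id="20240812-120000")
    enabled, log_path = enable_logging_with_config(config, verbose=True)
    # With file reporting, log_path is <reporting.base_dir>/indexing-engine.log.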

graphrag/index/__main__.py
@@ -63,11 +63,6 @@ if __name__ == "__main__":
         help="Create an initial configuration in the given path.",
         action="store_true",
     )
-    parser.add_argument(
-        "--overlay-defaults",
-        help="Overlay default configuration values on a provided configuration file (--config).",
-        action="store_true",
-    )
     parser.add_argument(
         "--skip-validations",
         help="Skip any preflight validation. Useful when running no LLM steps.",
@@ -83,21 +78,17 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()
 
-    if args.overlay_defaults and not args.config:
-        parser.error("--overlay-defaults requires --config")
-
     index_cli(
-        root=args.root,
+        root_dir=args.root,
         verbose=args.verbose or False,
         resume=args.resume,
        update_index_id=args.update_index,
         memprofile=args.memprofile or False,
         nocache=args.nocache or False,
         reporter=args.reporter,
-        config=args.config,
+        config_filepath=args.config,
         emit=args.emit,
         dryrun=args.dryrun or False,
         init=args.init or False,
-        overlay_defaults=args.overlay_defaults or False,
         skip_validations=args.skip_validations or False,
     )

graphrag/index/api.py
@@ -8,9 +8,9 @@ WARNING: This API is under development and may undergo changes in future releases.
 Backwards compatibility is not guaranteed at this time.
 """
 
-from graphrag.config.enums import CacheType
-from graphrag.config.models.graph_rag_config import GraphRagConfig
-from graphrag.config.resolve_timestamp_path import resolve_timestamp_path
+from pathlib import Path
+
+from graphrag.config import CacheType, GraphRagConfig
 
 from .cache.noop_pipeline_cache import NoopPipelineCache
 from .create_pipeline_config import create_pipeline_config
@@ -50,11 +50,7 @@ async def build_index(
     list[PipelineRunResult]
         The list of pipeline run results
     """
-    try:
-        resolve_timestamp_path(config.storage.base_dir, run_id)
-        resume = True
-    except ValueError as _:
-        resume = False
+    resume = Path(config.storage.base_dir).exists()
     pipeline_config = create_pipeline_config(config)
     pipeline_cache = (
         NoopPipelineCache() if config.cache.type == CacheType.none is None else None

graphrag/index/cli.py
@@ -11,13 +11,7 @@ import time
 import warnings
 from pathlib import Path
 
-from graphrag.config import create_graphrag_config
-from graphrag.config.config_file_loader import (
-    load_config_from_file,
-    resolve_config_path_with_root,
-)
-from graphrag.config.enums import CacheType
-from graphrag.config.logging import enable_logging_with_config
+from graphrag.config import CacheType, enable_logging_with_config, load_config
 
 from .api import build_index, update_index
 from .graph.extractors.claims.prompts import CLAIM_EXTRACTION_PROMPT
@@ -103,7 +97,7 @@ def _register_signal_handlers(reporter: ProgressReporter):
 
 
 def index_cli(
-    root: str,
+    root_dir: str,
     init: bool,
     verbose: bool,
     resume: str | None,
@@ -111,10 +105,9 @@
     memprofile: bool,
     nocache: bool,
     reporter: str | None,
-    config: str | None,
+    config_filepath: str | None,
     emit: str | None,
     dryrun: bool,
-    overlay_defaults: bool,
     skip_validations: bool,
 ):
     """Run the pipeline with the given config."""
@@ -123,41 +116,30 @@ def index_cli(
     run_id = resume or time.strftime("%Y%m%d-%H%M%S")
 
     if init:
-        _initialize_project_at(root, progress_reporter)
+        _initialize_project_at(root_dir, progress_reporter)
         sys.exit(0)
 
-    if overlay_defaults or config:
-        config_path = (
-            Path(root) / config if config else resolve_config_path_with_root(root)
-        )
-        default_config = load_config_from_file(config_path)
-    else:
-        try:
-            config_path = resolve_config_path_with_root(root)
-            default_config = load_config_from_file(config_path)
-        except FileNotFoundError:
-            default_config = create_graphrag_config(root_dir=root)
+    root = Path(root_dir).resolve()
+    config = load_config(root, config_filepath, run_id)
 
     if nocache:
-        default_config.cache.type = CacheType.none
+        config.cache.type = CacheType.none
 
-    enabled_logging, log_path = enable_logging_with_config(
-        default_config, run_id, verbose
-    )
+    enabled_logging, log_path = enable_logging_with_config(config, verbose)
     if enabled_logging:
         info(f"Logging enabled at {log_path}", True)
     else:
         info(
-            f"Logging not enabled for config {_redact(default_config.model_dump())}",
+            f"Logging not enabled for config {_redact(config.model_dump())}",
             True,
         )
 
     if skip_validations:
-        validate_config_names(progress_reporter, default_config)
+        validate_config_names(progress_reporter, config)
 
     info(f"Starting pipeline run for: {run_id}, {dryrun=}", verbose)
     info(
-        f"Using default configuration: {_redact(default_config.model_dump())}",
+        f"Using default configuration: {_redact(config.model_dump())}",
         verbose,
     )
@@ -169,26 +151,15 @@ def index_cli(
 
     _register_signal_handlers(progress_reporter)
 
-    if update_index_id:
-        outputs = asyncio.run(
-            update_index(
-                default_config,
-                memprofile,
-                update_index_id,
-                progress_reporter,
-                pipeline_emit,
-            )
-        )
-    else:
-        outputs = asyncio.run(
-            build_index(
-                default_config,
-                run_id,
-                memprofile,
-                progress_reporter,
-                pipeline_emit,
-            )
-        )
+    outputs = asyncio.run(
+        build_index(
+            config,
+            run_id,
+            memprofile,
+            progress_reporter,
+            pipeline_emit,
+        )
+    )
     encountered_errors = any(
         output.errors and len(output.errors) > 0 for output in outputs
     )

graphrag/prompt_tune/cli.py
@@ -5,11 +5,11 @@
 
 from pathlib import Path
 
+from graphrag.config import load_config
 from graphrag.index.progress import PrintProgressReporter
 from graphrag.prompt_tune.generator import MAX_TOKEN_COUNT
 from graphrag.prompt_tune.loader import (
     MIN_CHUNK_SIZE,
-    read_config_parameters,
 )
 
 from . import api
@@ -53,11 +53,12 @@ async def prompt_tune(
     - min_examples_required: The minimum number of examples required for entity extraction prompts.
     """
     reporter = PrintProgressReporter("")
-    graph_config = read_config_parameters(root, reporter, config)
+    root_path = Path(root).resolve()
+    graph_config = load_config(root_path, config)
 
     prompts = await api.generate_indexing_prompts(
         config=graph_config,
-        root=root,
+        root=str(root_path),
         chunk_size=chunk_size,
         limit=limit,
         selection_method=selection_method,
@@ -70,7 +71,7 @@ async def prompt_tune(
         k=k,
     )
 
-    output_path = Path(output)
+    output_path = (root_path / output).resolve()
     if output_path:
         reporter.info(f"Writing prompts to {output_path}")
         output_path.mkdir(parents=True, exist_ok=True)

graphrag/prompt_tune/loader/__init__.py
@@ -3,12 +3,10 @@
 
 """Fine-tuning config and data loader module."""
 
-from .config import read_config_parameters
 from .input import MIN_CHUNK_OVERLAP, MIN_CHUNK_SIZE, load_docs_in_chunks
 
 __all__ = [
     "MIN_CHUNK_OVERLAP",
     "MIN_CHUNK_SIZE",
     "load_docs_in_chunks",
-    "read_config_parameters",
 ]

graphrag/prompt_tune/loader/config.py (deleted file, 61 lines)
@@ -1,61 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""Config loading, parsing and handling module."""
-
-from pathlib import Path
-
-from graphrag.config import create_graphrag_config
-from graphrag.index.progress.types import ProgressReporter
-
-
-def read_config_parameters(
-    root: str, reporter: ProgressReporter, config: str | None = None
-):
-    """Read the configuration parameters from the settings file or environment variables.
-
-    Parameters
-    ----------
-    - root: The root directory where the parameters are.
-    - reporter: The progress reporter.
-    - config: The path to the settings file.
-    """
-    _root = Path(root)
-    settings_yaml = (
-        Path(config)
-        if config and Path(config).suffix in [".yaml", ".yml"]
-        else _root / "settings.yaml"
-    )
-    if not settings_yaml.exists():
-        settings_yaml = _root / "settings.yml"
-    if settings_yaml.exists():
-        reporter.info(f"Reading settings from {settings_yaml}")
-        with settings_yaml.open("rb") as file:
-            import yaml
-
-            data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict"))
-            return create_graphrag_config(data, root)
-
-    settings_json = (
-        Path(config)
-        if config and Path(config).suffix == ".json"
-        else _root / "settings.json"
-    )
-    if settings_yaml.exists():
-        reporter.info(f"Reading settings from {settings_yaml}")
-        with settings_yaml.open("rb") as file:
-            import yaml
-
-            data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict"))
-            return create_graphrag_config(data, root)
-
-    if settings_json.exists():
-        reporter.info(f"Reading settings from {settings_json}")
-        with settings_json.open("rb") as file:
-            import json
-
-            data = json.loads(file.read().decode(encoding="utf-8", errors="strict"))
-            return create_graphrag_config(data, root)
-
-    reporter.info("Reading settings from environment variables")
-    return create_graphrag_config(root_dir=root)

graphrag/query/api.py
@@ -24,8 +24,7 @@ from typing import Any
 import pandas as pd
 from pydantic import validate_call
 
-from graphrag.config.models.graph_rag_config import GraphRagConfig
-from graphrag.config.resolve_timestamp_path import resolve_timestamp_path
+from graphrag.config import GraphRagConfig
 from graphrag.index.progress.types import PrintProgressReporter
 from graphrag.model.entity import Entity
 from graphrag.query.structured_search.base import SearchResult  # noqa: TCH001
@@ -149,7 +148,6 @@ async def global_search_streaming(
 
 @validate_call(config={"arbitrary_types_allowed": True})
 async def local_search(
-    root_dir: str | None,
     config: GraphRagConfig,
     nodes: pd.DataFrame,
     entities: pd.DataFrame,
@@ -196,9 +194,8 @@ async def local_search(
 
     _entities = read_indexer_entities(nodes, entities, community_level)
 
-    base_dir = Path(str(root_dir)) / config.storage.base_dir
-    resolved_base_dir = resolve_timestamp_path(base_dir)
-    lancedb_dir = resolved_base_dir / "lancedb"
+    lancedb_dir = Path(config.storage.base_dir) / "lancedb"
 
     vector_store_args.update({"db_uri": str(lancedb_dir)})
     description_embedding_store = _get_embedding_description_store(
         entities=_entities,
@@ -227,7 +224,6 @@
 
 @validate_call(config={"arbitrary_types_allowed": True})
 async def local_search_streaming(
-    root_dir: str | None,
     config: GraphRagConfig,
     nodes: pd.DataFrame,
     entities: pd.DataFrame,
@@ -271,9 +267,8 @@ async def local_search_streaming(
 
     _entities = read_indexer_entities(nodes, entities, community_level)
 
-    base_dir = Path(str(root_dir)) / config.storage.base_dir
-    resolved_base_dir = resolve_timestamp_path(base_dir)
-    lancedb_dir = resolved_base_dir / "lancedb"
+    lancedb_dir = lancedb_dir = Path(config.storage.base_dir) / "lancedb"
 
     vector_store_args.update({"db_uri": str(lancedb_dir)})
     description_embedding_store = _get_embedding_description_store(
         entities=_entities,

graphrag/query/cli.py
@@ -4,17 +4,12 @@
 """Command line interface for the query module."""
 
 import asyncio
-import re
 import sys
 from pathlib import Path
 
 import pandas as pd
 
-from graphrag.config import (
-    GraphRagConfig,
-    create_graphrag_config,
-)
-from graphrag.config.resolve_timestamp_path import resolve_timestamp_path
+from graphrag.config import load_config, resolve_timestamp_path
 from graphrag.index.progress import PrintProgressReporter
 
 from . import api
@@ -25,7 +20,7 @@ reporter = PrintProgressReporter("")
 def run_global_search(
     config_filepath: str | None,
     data_dir: str | None,
-    root_dir: str | None,
+    root_dir: str,
     community_level: int,
     response_type: str,
     streaming: bool,
@@ -35,10 +30,15 @@
 
     Loads index files required for global search and calls the Query API.
     """
-    data_dir, root_dir, config = _configure_paths_and_settings(
-        data_dir, root_dir, config_filepath
-    )
-    data_path = Path(data_dir)
+    root = Path(root_dir).resolve()
+    config = load_config(root, config_filepath)
+
+    if data_dir:
+        config.storage.base_dir = str(
+            resolve_timestamp_path((root / data_dir).resolve())
+        )
+
+    data_path = Path(config.storage.base_dir).resolve()
 
     final_nodes: pd.DataFrame = pd.read_parquet(
         data_path / "create_final_nodes.parquet"
@@ -98,7 +98,7 @@
 def run_local_search(
     config_filepath: str | None,
     data_dir: str | None,
-    root_dir: str | None,
+    root_dir: str,
     community_level: int,
     response_type: str,
     streaming: bool,
@@ -108,10 +108,15 @@
 
     Loads index files required for local search and calls the Query API.
     """
-    data_dir, root_dir, config = _configure_paths_and_settings(
-        data_dir, root_dir, config_filepath
-    )
-    data_path = Path(data_dir)
+    root = Path(root_dir).resolve()
+    config = load_config(root, config_filepath)
+
+    if data_dir:
+        config.storage.base_dir = str(
+            resolve_timestamp_path((root / data_dir).resolve())
+        )
+
+    data_path = Path(config.storage.base_dir).resolve()
 
     final_nodes = pd.read_parquet(data_path / "create_final_nodes.parquet")
     final_community_reports = pd.read_parquet(
@@ -137,7 +142,6 @@ def run_local_search(
         context_data = None
         get_context_data = True
         async for stream_chunk in api.local_search_streaming(
-            root_dir=root_dir,
             config=config,
             nodes=final_nodes,
             entities=final_entities,
@@ -163,7 +167,6 @@ def run_local_search(
     # not streaming
     response, context_data = asyncio.run(
         api.local_search(
-            root_dir=root_dir,
            config=config,
            nodes=final_nodes,
            entities=final_entities,
@@ -180,77 +183,3 @@ def run_local_search(
     # NOTE: we return the response and context data here purely as a complete demonstration of the API.
     # External users should use the API directly to get the response and context data.
     return response, context_data
-
-
-def _configure_paths_and_settings(
-    data_dir: str | None,
-    root_dir: str | None,
-    config_filepath: str | None,
-) -> tuple[str, str | None, GraphRagConfig]:
-    config = _create_graphrag_config(root_dir, config_filepath)
-    if data_dir is None and root_dir is None:
-        msg = "Either data_dir or root_dir must be provided."
-        raise ValueError(msg)
-    if data_dir is None:
-        base_dir = Path(str(root_dir)) / config.storage.base_dir
-        data_dir = str(resolve_timestamp_path(base_dir))
-    return data_dir, root_dir, config
-
-
-def _infer_data_dir(root: str) -> str:
-    output = Path(root) / "output"
-    # use the latest data-run folder
-    if output.exists():
-        expr = re.compile(r"\d{8}-\d{6}")
-        filtered = [f for f in output.iterdir() if f.is_dir() and expr.match(f.name)]
-        folders = sorted(filtered, key=lambda f: f.name, reverse=True)
-        if len(folders) > 0:
-            folder = folders[0]
-            return str((folder / "artifacts").absolute())
-    msg = f"Could not infer data directory from root={root}"
-    raise ValueError(msg)
-
-
-def _create_graphrag_config(
-    root: str | None,
-    config_filepath: str | None,
-) -> GraphRagConfig:
-    """Create a GraphRag configuration."""
-    return _read_config_parameters(root or "./", config_filepath)
-
-
-def _read_config_parameters(root: str, config: str | None):
-    _root = Path(root)
-    settings_yaml = (
-        Path(config)
-        if config and Path(config).suffix in [".yaml", ".yml"]
-        else _root / "settings.yaml"
-    )
-    if not settings_yaml.exists():
-        settings_yaml = _root / "settings.yml"
-
-    if settings_yaml.exists():
-        reporter.info(f"Reading settings from {settings_yaml}")
-        with settings_yaml.open(
-            "rb",
-        ) as file:
-            import yaml
-
-            data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict"))
-            return create_graphrag_config(data, root)
-
-    settings_json = (
-        Path(config)
-        if config and Path(config).suffix == ".json"
-        else _root / "settings.json"
-    )
-    if settings_json.exists():
-        reporter.info(f"Reading settings from {settings_json}")
-        with settings_json.open("rb") as file:
-            import json
-
-            data = json.loads(file.read().decode(encoding="utf-8", errors="strict"))
-            return create_graphrag_config(data, root)
-
-    reporter.info("Reading settings from environment variables")
-    return create_graphrag_config(root_dir=root)

pyproject.toml
@@ -123,6 +123,7 @@ _convert_global_search_nb = 'jupyter nbconvert --output-dir=docsite/posts/query/'
 _semversioner_release = "semversioner release"
 _semversioner_changelog = "semversioner changelog > CHANGELOG.md"
 _semversioner_update_toml_version = "update-toml update --path tool.poetry.version --value $(poetry run semversioner current-version)"
+semversioner_add = "semversioner add-change"
 coverage_report = 'coverage report --omit "**/tests/**" --show-missing'
 check_format = 'ruff format . --check --preview'
 fix = "ruff --preview check --fix ."
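
Assuming the usual poethepoet invocation for tasks defined in this table, the new task would be run as:

    poetry run poe semversioner_add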

tests/unit/config/test_resolve_timestamp_path.py (new file, 33 lines)
@@ -0,0 +1,33 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+from pathlib import Path
+
+from graphrag.config.resolve_timestamp_path import resolve_timestamp_path
+
+
+def test_resolve_timestamp_path_no_timestamp_with_run_id():
+    path = Path("path/to/data")
+    result = resolve_timestamp_path(path, "20240812-121000")
+    assert result == path
+
+
+def test_resolve_timestamp_path_no_timestamp_without_run_id():
+    path = Path("path/to/data")
+    result = resolve_timestamp_path(path)
+    assert result == path
+
+
+def test_resolve_timestamp_path_with_timestamp_and_run_id():
+    path = Path("some/path/${timestamp}/data")
+    expected = Path("some/path/20240812/data")
+    result = resolve_timestamp_path(path, "20240812")
+    assert result == expected
+
+
+def test_resolve_timestamp_path_with_timestamp_and_inferred_directory():
+    cwd = Path(__file__).parent
+    path = cwd / "fixtures/timestamp_dirs/${timestamp}/data"
+    expected = cwd / "fixtures/timestamp_dirs/20240812-120000/data"
+    result = resolve_timestamp_path(path)
+    assert result == expected
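
The final test leans on an on-disk fixture, which suggests that when no run id is supplied the ${timestamp} segment is inferred from an existing timestamp-named directory. A hedged sketch of that reading (the tempfile layout is illustrative, not from the repo):

    import tempfile
    from pathlib import Path

    from graphrag.config.resolve_timestamp_path import resolve_timestamp_path

    with tempfile.TemporaryDirectory() as tmp:
        (Path(tmp) / "20240812-120000").mkdir()  # one timestamp-named run dir
        resolved = resolve_timestamp_path(Path(tmp) / "${timestamp}" / "data")
        assert resolved == Path(tmp) / "20240812-120000" / "data"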

deleted test file under tests/unit/query/ (32 lines)
@@ -1,32 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-from pathlib import Path
-
-import pytest
-
-from graphrag.query.cli import _infer_data_dir
-
-
-def test_infer_data_dir():
-    root = "./tests/unit/query/data/defaults"
-    result = Path(_infer_data_dir(root))
-    assert result.parts[-2] == "20240812-121000"
-
-
-def test_infer_data_dir_ignores_hidden_files():
-    """A hidden file, starting with '.', will naturally be selected as latest data directory."""
-    root = "./tests/unit/query/data/hidden"
-    result = Path(_infer_data_dir(root))
-    assert result.parts[-2] == "20240812-121000"
-
-
-def test_infer_data_dir_ignores_non_numeric():
-    root = "./tests/unit/query/data/non-numeric"
-    result = Path(_infer_data_dir(root))
-    assert result.parts[-2] == "20240812-121000"
-
-
-def test_infer_data_dir_throws_on_no_match():
-    root = "./tests/unit/query/data/empty"
-    with pytest.raises(ValueError):  # noqa PT011 (this is what is actually thrown...)
-        _infer_data_dir(root)