Merge from main

Alonso Guevara 2024-09-03 16:34:52 -06:00
commit 41ea554fda
17 changed files with 176 additions and 285 deletions

View File

@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Consistent config loading. Resolves #99 and Resolves #1049"
+}

View File

@@ -3,6 +3,7 @@
 """The Indexing Engine default config package root."""
+from .config_file_loader import load_config_from_file, search_for_config_in_root_dir
 from .create_graphrag_config import (
     create_graphrag_config,
 )
@@ -42,6 +43,8 @@ from .input_models import (
     TextEmbeddingConfigInput,
     UmapConfigInput,
 )
+from .load_config import load_config
+from .logging import enable_logging_with_config
 from .models import (
     CacheConfig,
     ChunkingConfig,
@@ -65,6 +68,7 @@ from .models import (
     UmapConfig,
 )
 from .read_dotenv import read_dotenv
+from .resolve_timestamp_path import resolve_timestamp_path

 __all__ = [
     "ApiKeyMissingError",
@@ -119,5 +123,10 @@ __all__ = [
     "UmapConfig",
     "UmapConfigInput",
     "create_graphrag_config",
+    "enable_logging_with_config",
+    "load_config",
+    "load_config_from_file",
     "read_dotenv",
+    "resolve_timestamp_path",
+    "search_for_config_in_root_dir",
 ]
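For orientation (not part of the diff itself): the updated __all__ above means callers can import the new helpers directly from the package root. A minimal sketch, with an illustrative project directory:

# Sketch only; "./ragtest" is a placeholder project directory.
from graphrag.config import load_config, search_for_config_in_root_dir

settings_path = search_for_config_in_root_dir("./ragtest")  # Path to settings.* or None
config = load_config("./ragtest")  # falls back to generated defaults when no settings file exists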

View File

@@ -9,13 +9,13 @@ from pathlib import Path
 import yaml

-from . import create_graphrag_config
+from .create_graphrag_config import create_graphrag_config
 from .models.graph_rag_config import GraphRagConfig

 _default_config_files = ["settings.yaml", "settings.yml", "settings.json"]

-def resolve_config_path_with_root(root: str | Path) -> Path:
+def search_for_config_in_root_dir(root: str | Path) -> Path | None:
     """Resolve the config path from the given root directory.

     Parameters
@@ -26,13 +26,9 @@ def resolve_config_path_with_root(root: str | Path) -> Path:
     Returns
     -------
-    Path
-        The resolved config file path.
-
-    Raises
-    ------
-    FileNotFoundError
-        If the config file is not found or cannot be resolved for the directory.
+    Path | None
+        returns a Path if there is a config in the root directory
+        Otherwise returns None.
     """
     root = Path(root)
@@ -44,8 +40,7 @@ def resolve_config_path_with_root(root: str | Path) -> Path:
         if (root / file).is_file():
             return root / file

-    msg = f"Unable to resolve config file for parent directory: {root}"
-    raise FileNotFoundError(msg)
+    return None

 class ConfigFileLoader(ABC):

View File

@@ -0,0 +1,65 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Default method for loading config."""
+
+from pathlib import Path
+
+from .config_file_loader import load_config_from_file, search_for_config_in_root_dir
+from .create_graphrag_config import create_graphrag_config
+from .models.graph_rag_config import GraphRagConfig
+from .resolve_timestamp_path import resolve_timestamp_path
+
+
+def load_config(
+    root_dir: str | Path,
+    config_filepath: str | Path | None = None,
+    run_id: str | None = None,
+) -> GraphRagConfig:
+    """Load configuration from a file or create a default configuration.
+
+    If a config file is not found the default configuration is created.
+
+    Parameters
+    ----------
+    root_dir : str | Path
+        The root directory of the project. Will search for the config file in this directory.
+    config_filepath : str | Path | None
+        The path to the config file.
+        If None, searches for config file in root and
+        if not found creates a default configuration.
+    run_id : str | None
+        The run id to use for resolving timestamp_paths.
+    """
+    root = Path(root_dir).resolve()
+
+    # If user specified a config file path then it is required
+    if config_filepath:
+        config_path = (root / config_filepath).resolve()
+        if not config_path.exists():
+            msg = f"Specified Config file not found: {config_path}"
+            raise FileNotFoundError(msg)
+    # Else optionally resolve the config path from the root directory if it exists
+    else:
+        config_path = search_for_config_in_root_dir(root)
+
+    if config_path:
+        config = load_config_from_file(config_path)
+    else:
+        config = create_graphrag_config(root_dir=str(root))
+
+    if run_id:
+        config.storage.base_dir = str(
+            resolve_timestamp_path((root / config.storage.base_dir).resolve(), run_id)
+        )
+        config.reporting.base_dir = str(
+            resolve_timestamp_path((root / config.reporting.base_dir).resolve(), run_id)
+        )
+    else:
+        config.storage.base_dir = str(
+            resolve_timestamp_path((root / config.storage.base_dir).resolve())
+        )
+        config.reporting.base_dir = str(
+            resolve_timestamp_path((root / config.reporting.base_dir).resolve())
+        )
+
+    return config
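A hedged usage sketch of the new loader added above; the root directory, file name, and run id are illustrative placeholders, not values from this change:

from graphrag.config import load_config

config = load_config("./ragtest")  # search root for settings.yaml / .yml / .json, else defaults
config = load_config("./ragtest", "custom-settings.yaml")  # explicit file; raises FileNotFoundError if missing
config = load_config("./ragtest", run_id="20240812-121000")  # run id used when resolving ${timestamp} in base dirs
print(config.storage.base_dir, config.reporting.base_dir)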

View File

@@ -8,7 +8,6 @@ from pathlib import Path
 from .enums import ReportingType
 from .models.graph_rag_config import GraphRagConfig
-from .resolve_timestamp_path import resolve_timestamp_path

 def enable_logging(log_filepath: str | Path, verbose: bool = False) -> None:
@@ -35,7 +34,7 @@ def enable_logging(log_filepath: str | Path, verbose: bool = False) -> None:
 def enable_logging_with_config(
-    config: GraphRagConfig, timestamp_value: str, verbose: bool = False
+    config: GraphRagConfig, verbose: bool = False
 ) -> tuple[bool, str]:
     """Enable logging to a file based on the config.
@@ -56,10 +55,7 @@ def enable_logging_with_config(
         (True, str) if logging was enabled.
     """
     if config.reporting.type == ReportingType.file:
-        log_path = resolve_timestamp_path(
-            Path(config.root_dir) / config.reporting.base_dir / "indexing-engine.log",
-            timestamp_value,
-        )
+        log_path = Path(config.reporting.base_dir) / "indexing-engine.log"
         enable_logging(log_path, verbose)
         return (True, str(log_path))
     return (False, "")
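With timestamp handling moved into load_config, the call site simplifies to passing only the config and verbosity. A minimal sketch assuming the config comes from load_config; the root and run id are placeholders:

from graphrag.config import enable_logging_with_config, load_config

config = load_config("./ragtest", run_id="20240812-121000")  # placeholder root and run id
enabled, log_path = enable_logging_with_config(config, verbose=True)
if enabled:
    print(f"Logging to {log_path}")  # <reporting.base_dir>/indexing-engine.log when reporting type is file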

View File

@@ -63,11 +63,6 @@ if __name__ == "__main__":
         help="Create an initial configuration in the given path.",
         action="store_true",
     )
-    parser.add_argument(
-        "--overlay-defaults",
-        help="Overlay default configuration values on a provided configuration file (--config).",
-        action="store_true",
-    )
     parser.add_argument(
         "--skip-validations",
         help="Skip any preflight validation. Useful when running no LLM steps.",
@@ -83,21 +78,17 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()
-    if args.overlay_defaults and not args.config:
-        parser.error("--overlay-defaults requires --config")
     index_cli(
-        root=args.root,
+        root_dir=args.root,
         verbose=args.verbose or False,
         resume=args.resume,
         update_index_id=args.update_index,
         memprofile=args.memprofile or False,
         nocache=args.nocache or False,
         reporter=args.reporter,
-        config=args.config,
+        config_filepath=args.config,
         emit=args.emit,
         dryrun=args.dryrun or False,
         init=args.init or False,
-        overlay_defaults=args.overlay_defaults or False,
         skip_validations=args.skip_validations or False,
     )

View File

@@ -8,9 +8,9 @@ WARNING: This API is under development and may undergo changes in future release
 Backwards compatibility is not guaranteed at this time.
 """
-from graphrag.config.enums import CacheType
-from graphrag.config.models.graph_rag_config import GraphRagConfig
-from graphrag.config.resolve_timestamp_path import resolve_timestamp_path
+from pathlib import Path
+
+from graphrag.config import CacheType, GraphRagConfig

 from .cache.noop_pipeline_cache import NoopPipelineCache
 from .create_pipeline_config import create_pipeline_config
@@ -50,11 +50,7 @@ async def build_index(
     list[PipelineRunResult]
         The list of pipeline run results
     """
-    try:
-        resolve_timestamp_path(config.storage.base_dir, run_id)
-        resume = True
-    except ValueError as _:
-        resume = False
+    resume = Path(config.storage.base_dir).exists()
     pipeline_config = create_pipeline_config(config)
     pipeline_cache = (
         NoopPipelineCache() if config.cache.type == CacheType.none is None else None

View File

@@ -11,13 +11,7 @@ import time
 import warnings
 from pathlib import Path

-from graphrag.config import create_graphrag_config
-from graphrag.config.config_file_loader import (
-    load_config_from_file,
-    resolve_config_path_with_root,
-)
-from graphrag.config.enums import CacheType
-from graphrag.config.logging import enable_logging_with_config
+from graphrag.config import CacheType, enable_logging_with_config, load_config

 from .api import build_index, update_index
 from .graph.extractors.claims.prompts import CLAIM_EXTRACTION_PROMPT
@@ -103,7 +97,7 @@ def _register_signal_handlers(reporter: ProgressReporter):
 def index_cli(
-    root: str,
+    root_dir: str,
     init: bool,
     verbose: bool,
     resume: str | None,
@@ -111,10 +105,9 @@ def index_cli(
     memprofile: bool,
     nocache: bool,
     reporter: str | None,
-    config: str | None,
+    config_filepath: str | None,
     emit: str | None,
     dryrun: bool,
-    overlay_defaults: bool,
     skip_validations: bool,
 ):
     """Run the pipeline with the given config."""
@@ -123,41 +116,30 @@ def index_cli(
     run_id = resume or time.strftime("%Y%m%d-%H%M%S")
     if init:
-        _initialize_project_at(root, progress_reporter)
+        _initialize_project_at(root_dir, progress_reporter)
         sys.exit(0)

-    if overlay_defaults or config:
-        config_path = (
-            Path(root) / config if config else resolve_config_path_with_root(root)
-        )
-        default_config = load_config_from_file(config_path)
-    else:
-        try:
-            config_path = resolve_config_path_with_root(root)
-            default_config = load_config_from_file(config_path)
-        except FileNotFoundError:
-            default_config = create_graphrag_config(root_dir=root)
+    root = Path(root_dir).resolve()
+    config = load_config(root, config_filepath, run_id)

     if nocache:
-        default_config.cache.type = CacheType.none
+        config.cache.type = CacheType.none

-    enabled_logging, log_path = enable_logging_with_config(
-        default_config, run_id, verbose
-    )
+    enabled_logging, log_path = enable_logging_with_config(config, verbose)
     if enabled_logging:
         info(f"Logging enabled at {log_path}", True)
     else:
         info(
-            f"Logging not enabled for config {_redact(default_config.model_dump())}",
+            f"Logging not enabled for config {_redact(config.model_dump())}",
             True,
         )

     if skip_validations:
-        validate_config_names(progress_reporter, default_config)
+        validate_config_names(progress_reporter, config)
     info(f"Starting pipeline run for: {run_id}, {dryrun=}", verbose)
     info(
-        f"Using default configuration: {_redact(default_config.model_dump())}",
+        f"Using default configuration: {_redact(config.model_dump())}",
         verbose,
     )
@@ -169,20 +151,9 @@ def index_cli(
     _register_signal_handlers(progress_reporter)

-    if update_index_id:
-        outputs = asyncio.run(
-            update_index(
-                default_config,
-                memprofile,
-                update_index_id,
-                progress_reporter,
-                pipeline_emit,
-            )
-        )
-    else:
-        outputs = asyncio.run(
-            build_index(
-                default_config,
-                run_id,
-                memprofile,
-                progress_reporter,
+    outputs = asyncio.run(
+        build_index(
+            config,
+            run_id,
+            memprofile,
+            progress_reporter,

View File

@@ -5,11 +5,11 @@
 from pathlib import Path

+from graphrag.config import load_config
 from graphrag.index.progress import PrintProgressReporter
 from graphrag.prompt_tune.generator import MAX_TOKEN_COUNT
 from graphrag.prompt_tune.loader import (
     MIN_CHUNK_SIZE,
-    read_config_parameters,
 )

 from . import api
@@ -53,11 +53,12 @@ async def prompt_tune(
     - min_examples_required: The minimum number of examples required for entity extraction prompts.
     """
     reporter = PrintProgressReporter("")
-    graph_config = read_config_parameters(root, reporter, config)
+    root_path = Path(root).resolve()
+    graph_config = load_config(root_path, config)

     prompts = await api.generate_indexing_prompts(
         config=graph_config,
-        root=root,
+        root=str(root_path),
         chunk_size=chunk_size,
         limit=limit,
         selection_method=selection_method,
@@ -70,7 +71,7 @@ async def prompt_tune(
         k=k,
     )

-    output_path = Path(output)
+    output_path = (root_path / output).resolve()
     if output_path:
         reporter.info(f"Writing prompts to {output_path}")
         output_path.mkdir(parents=True, exist_ok=True)

View File

@@ -3,12 +3,10 @@
 """Fine-tuning config and data loader module."""

-from .config import read_config_parameters
 from .input import MIN_CHUNK_OVERLAP, MIN_CHUNK_SIZE, load_docs_in_chunks

 __all__ = [
     "MIN_CHUNK_OVERLAP",
     "MIN_CHUNK_SIZE",
     "load_docs_in_chunks",
-    "read_config_parameters",
 ]

View File

@@ -1,61 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""Config loading, parsing and handling module."""
-
-from pathlib import Path
-
-from graphrag.config import create_graphrag_config
-from graphrag.index.progress.types import ProgressReporter
-
-
-def read_config_parameters(
-    root: str, reporter: ProgressReporter, config: str | None = None
-):
-    """Read the configuration parameters from the settings file or environment variables.
-
-    Parameters
-    ----------
-    - root: The root directory where the parameters are.
-    - reporter: The progress reporter.
-    - config: The path to the settings file.
-    """
-    _root = Path(root)
-    settings_yaml = (
-        Path(config)
-        if config and Path(config).suffix in [".yaml", ".yml"]
-        else _root / "settings.yaml"
-    )
-    if not settings_yaml.exists():
-        settings_yaml = _root / "settings.yml"
-
-    if settings_yaml.exists():
-        reporter.info(f"Reading settings from {settings_yaml}")
-        with settings_yaml.open("rb") as file:
-            import yaml
-
-            data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict"))
-            return create_graphrag_config(data, root)
-
-    settings_json = (
-        Path(config)
-        if config and Path(config).suffix == ".json"
-        else _root / "settings.json"
-    )
-
-    if settings_yaml.exists():
-        reporter.info(f"Reading settings from {settings_yaml}")
-        with settings_yaml.open("rb") as file:
-            import yaml
-
-            data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict"))
-            return create_graphrag_config(data, root)
-
-    if settings_json.exists():
-        reporter.info(f"Reading settings from {settings_json}")
-        with settings_json.open("rb") as file:
-            import json
-
-            data = json.loads(file.read().decode(encoding="utf-8", errors="strict"))
-            return create_graphrag_config(data, root)
-
-    reporter.info("Reading settings from environment variables")
-    return create_graphrag_config(root_dir=root)

View File

@@ -24,8 +24,7 @@ from typing import Any
 import pandas as pd
 from pydantic import validate_call

-from graphrag.config.models.graph_rag_config import GraphRagConfig
-from graphrag.config.resolve_timestamp_path import resolve_timestamp_path
+from graphrag.config import GraphRagConfig
 from graphrag.index.progress.types import PrintProgressReporter
 from graphrag.model.entity import Entity
 from graphrag.query.structured_search.base import SearchResult  # noqa: TCH001
@@ -149,7 +148,6 @@ async def global_search_streaming(
 @validate_call(config={"arbitrary_types_allowed": True})
 async def local_search(
-    root_dir: str | None,
     config: GraphRagConfig,
     nodes: pd.DataFrame,
     entities: pd.DataFrame,
@@ -196,9 +194,8 @@ async def local_search(
     _entities = read_indexer_entities(nodes, entities, community_level)

-    base_dir = Path(str(root_dir)) / config.storage.base_dir
-    resolved_base_dir = resolve_timestamp_path(base_dir)
-    lancedb_dir = resolved_base_dir / "lancedb"
+    lancedb_dir = Path(config.storage.base_dir) / "lancedb"
     vector_store_args.update({"db_uri": str(lancedb_dir)})
     description_embedding_store = _get_embedding_description_store(
         entities=_entities,
@@ -227,7 +224,6 @@
 @validate_call(config={"arbitrary_types_allowed": True})
 async def local_search_streaming(
-    root_dir: str | None,
     config: GraphRagConfig,
     nodes: pd.DataFrame,
     entities: pd.DataFrame,
@@ -271,9 +267,8 @@ async def local_search_streaming(
     _entities = read_indexer_entities(nodes, entities, community_level)

-    base_dir = Path(str(root_dir)) / config.storage.base_dir
-    resolved_base_dir = resolve_timestamp_path(base_dir)
-    lancedb_dir = resolved_base_dir / "lancedb"
+    lancedb_dir = Path(config.storage.base_dir) / "lancedb"
     vector_store_args.update({"db_uri": str(lancedb_dir)})
     description_embedding_store = _get_embedding_description_store(
         entities=_entities,

View File

@@ -4,17 +4,12 @@
 """Command line interface for the query module."""

 import asyncio
-import re
 import sys
 from pathlib import Path

 import pandas as pd

-from graphrag.config import (
-    GraphRagConfig,
-    create_graphrag_config,
-)
-from graphrag.config.resolve_timestamp_path import resolve_timestamp_path
+from graphrag.config import load_config, resolve_timestamp_path
 from graphrag.index.progress import PrintProgressReporter

 from . import api
@@ -25,7 +20,7 @@ reporter = PrintProgressReporter("")
 def run_global_search(
     config_filepath: str | None,
     data_dir: str | None,
-    root_dir: str | None,
+    root_dir: str,
     community_level: int,
     response_type: str,
     streaming: bool,
@@ -35,10 +30,15 @@ def run_global_search(
     Loads index files required for global search and calls the Query API.
     """
-    data_dir, root_dir, config = _configure_paths_and_settings(
-        data_dir, root_dir, config_filepath
-    )
-    data_path = Path(data_dir)
+    root = Path(root_dir).resolve()
+    config = load_config(root, config_filepath)
+
+    if data_dir:
+        config.storage.base_dir = str(
+            resolve_timestamp_path((root / data_dir).resolve())
+        )
+
+    data_path = Path(config.storage.base_dir).resolve()

     final_nodes: pd.DataFrame = pd.read_parquet(
         data_path / "create_final_nodes.parquet"
@@ -98,7 +98,7 @@
 def run_local_search(
     config_filepath: str | None,
     data_dir: str | None,
-    root_dir: str | None,
+    root_dir: str,
     community_level: int,
     response_type: str,
     streaming: bool,
@@ -108,10 +108,15 @@
     Loads index files required for local search and calls the Query API.
     """
-    data_dir, root_dir, config = _configure_paths_and_settings(
-        data_dir, root_dir, config_filepath
-    )
-    data_path = Path(data_dir)
+    root = Path(root_dir).resolve()
+    config = load_config(root, config_filepath)
+
+    if data_dir:
+        config.storage.base_dir = str(
+            resolve_timestamp_path((root / data_dir).resolve())
+        )
+
+    data_path = Path(config.storage.base_dir).resolve()

     final_nodes = pd.read_parquet(data_path / "create_final_nodes.parquet")
     final_community_reports = pd.read_parquet(
@@ -137,7 +142,6 @@
         context_data = None
         get_context_data = True
         async for stream_chunk in api.local_search_streaming(
-            root_dir=root_dir,
             config=config,
             nodes=final_nodes,
             entities=final_entities,
@@ -163,7 +167,6 @@
     # not streaming
     response, context_data = asyncio.run(
         api.local_search(
-            root_dir=root_dir,
            config=config,
            nodes=final_nodes,
            entities=final_entities,
@@ -180,77 +183,3 @@
     # NOTE: we return the response and context data here purely as a complete demonstration of the API.
     # External users should use the API directly to get the response and context data.
     return response, context_data
-
-
-def _configure_paths_and_settings(
-    data_dir: str | None,
-    root_dir: str | None,
-    config_filepath: str | None,
-) -> tuple[str, str | None, GraphRagConfig]:
-    config = _create_graphrag_config(root_dir, config_filepath)
-    if data_dir is None and root_dir is None:
-        msg = "Either data_dir or root_dir must be provided."
-        raise ValueError(msg)
-    if data_dir is None:
-        base_dir = Path(str(root_dir)) / config.storage.base_dir
-        data_dir = str(resolve_timestamp_path(base_dir))
-    return data_dir, root_dir, config
-
-
-def _infer_data_dir(root: str) -> str:
-    output = Path(root) / "output"
-    # use the latest data-run folder
-    if output.exists():
-        expr = re.compile(r"\d{8}-\d{6}")
-        filtered = [f for f in output.iterdir() if f.is_dir() and expr.match(f.name)]
-        folders = sorted(filtered, key=lambda f: f.name, reverse=True)
-        if len(folders) > 0:
-            folder = folders[0]
-            return str((folder / "artifacts").absolute())
-    msg = f"Could not infer data directory from root={root}"
-    raise ValueError(msg)
-
-
-def _create_graphrag_config(
-    root: str | None,
-    config_filepath: str | None,
-) -> GraphRagConfig:
-    """Create a GraphRag configuration."""
-    return _read_config_parameters(root or "./", config_filepath)
-
-
-def _read_config_parameters(root: str, config: str | None):
-    _root = Path(root)
-    settings_yaml = (
-        Path(config)
-        if config and Path(config).suffix in [".yaml", ".yml"]
-        else _root / "settings.yaml"
-    )
-    if not settings_yaml.exists():
-        settings_yaml = _root / "settings.yml"
-
-    if settings_yaml.exists():
-        reporter.info(f"Reading settings from {settings_yaml}")
-        with settings_yaml.open(
-            "rb",
-        ) as file:
-            import yaml
-
-            data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict"))
-            return create_graphrag_config(data, root)
-
-    settings_json = (
-        Path(config)
-        if config and Path(config).suffix == ".json"
-        else _root / "settings.json"
-    )
-
-    if settings_json.exists():
-        reporter.info(f"Reading settings from {settings_json}")
-        with settings_json.open("rb") as file:
-            import json
-
-            data = json.loads(file.read().decode(encoding="utf-8", errors="strict"))
-            return create_graphrag_config(data, root)
-
-    reporter.info("Reading settings from environment variables")
-    return create_graphrag_config(root_dir=root)

View File

@@ -123,6 +123,7 @@ _convert_global_search_nb = 'jupyter nbconvert --output-dir=docsite/posts/query/
 _semversioner_release = "semversioner release"
 _semversioner_changelog = "semversioner changelog > CHANGELOG.md"
 _semversioner_update_toml_version = "update-toml update --path tool.poetry.version --value $(poetry run semversioner current-version)"
+semversioner_add = "semversioner add-change"
 coverage_report = 'coverage report --omit "**/tests/**" --show-missing'
 check_format = 'ruff format . --check --preview'
 fix = "ruff --preview check --fix ."

View File

@@ -0,0 +1,33 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+from pathlib import Path
+
+from graphrag.config.resolve_timestamp_path import resolve_timestamp_path
+
+
+def test_resolve_timestamp_path_no_timestamp_with_run_id():
+    path = Path("path/to/data")
+    result = resolve_timestamp_path(path, "20240812-121000")
+    assert result == path
+
+
+def test_resolve_timestamp_path_no_timestamp_without_run_id():
+    path = Path("path/to/data")
+    result = resolve_timestamp_path(path)
+    assert result == path
+
+
+def test_resolve_timestamp_path_with_timestamp_and_run_id():
+    path = Path("some/path/${timestamp}/data")
+    expected = Path("some/path/20240812/data")
+    result = resolve_timestamp_path(path, "20240812")
+    assert result == expected
+
+
+def test_resolve_timestamp_path_with_timestamp_and_inferred_directory():
+    cwd = Path(__file__).parent
+    path = cwd / "fixtures/timestamp_dirs/${timestamp}/data"
+    expected = cwd / "fixtures/timestamp_dirs/20240812-120000/data"
+    result = resolve_timestamp_path(path)
+    assert result == expected

View File

@@ -1,32 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-from pathlib import Path
-
-import pytest
-
-from graphrag.query.cli import _infer_data_dir
-
-
-def test_infer_data_dir():
-    root = "./tests/unit/query/data/defaults"
-    result = Path(_infer_data_dir(root))
-    assert result.parts[-2] == "20240812-121000"
-
-
-def test_infer_data_dir_ignores_hidden_files():
-    """A hidden file, starting with '.', will naturally be selected as latest data directory."""
-    root = "./tests/unit/query/data/hidden"
-    result = Path(_infer_data_dir(root))
-    assert result.parts[-2] == "20240812-121000"
-
-
-def test_infer_data_dir_ignores_non_numeric():
-    root = "./tests/unit/query/data/non-numeric"
-    result = Path(_infer_data_dir(root))
-    assert result.parts[-2] == "20240812-121000"
-
-
-def test_infer_data_dir_throws_on_no_match():
-    root = "./tests/unit/query/data/empty"
-    with pytest.raises(ValueError):  # noqa PT011 (this is what is actually thrown...)
-        _infer_data_dir(root)