Mirror of https://github.com/microsoft/graphrag.git (synced 2025-12-14 16:47:18 +00:00)

Commit 41ea554fda: Merge from main
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Consistent config loading. Resolves #99 and Resolves #1049"
+}
@@ -3,6 +3,7 @@

 """The Indexing Engine default config package root."""

+from .config_file_loader import load_config_from_file, search_for_config_in_root_dir
 from .create_graphrag_config import (
     create_graphrag_config,
 )
@@ -42,6 +43,8 @@ from .input_models import (
     TextEmbeddingConfigInput,
     UmapConfigInput,
 )
+from .load_config import load_config
+from .logging import enable_logging_with_config
 from .models import (
     CacheConfig,
     ChunkingConfig,
@@ -65,6 +68,7 @@ from .models import (
     UmapConfig,
 )
 from .read_dotenv import read_dotenv
+from .resolve_timestamp_path import resolve_timestamp_path

 __all__ = [
     "ApiKeyMissingError",
@@ -119,5 +123,10 @@ __all__ = [
     "UmapConfig",
     "UmapConfigInput",
     "create_graphrag_config",
+    "enable_logging_with_config",
+    "load_config",
+    "load_config_from_file",
     "read_dotenv",
+    "resolve_timestamp_path",
+    "search_for_config_in_root_dir",
 ]
@@ -9,13 +9,13 @@ from pathlib import Path

 import yaml

-from . import create_graphrag_config
+from .create_graphrag_config import create_graphrag_config
 from .models.graph_rag_config import GraphRagConfig

 _default_config_files = ["settings.yaml", "settings.yml", "settings.json"]


-def resolve_config_path_with_root(root: str | Path) -> Path:
+def search_for_config_in_root_dir(root: str | Path) -> Path | None:
     """Resolve the config path from the given root directory.

     Parameters
@@ -26,13 +26,9 @@ def resolve_config_path_with_root(root: str | Path) -> Path:

     Returns
     -------
-    Path
-        The resolved config file path.
-    Raises
-    ------
-    FileNotFoundError
-        If the config file is not found or cannot be resolved for the directory.
+    Path | None
+        returns a Path if there is a config in the root directory
+        Otherwise returns None.
     """
     root = Path(root)

@@ -44,8 +40,7 @@ def resolve_config_path_with_root(root: str | Path) -> Path:
         if (root / file).is_file():
             return root / file

-    msg = f"Unable to resolve config file for parent directory: {root}"
-    raise FileNotFoundError(msg)
+    return None


 class ConfigFileLoader(ABC):
graphrag/config/load_config.py (new file, 65 lines)
@@ -0,0 +1,65 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""Default method for loading config."""
+
+from pathlib import Path
+
+from .config_file_loader import load_config_from_file, search_for_config_in_root_dir
+from .create_graphrag_config import create_graphrag_config
+from .models.graph_rag_config import GraphRagConfig
+from .resolve_timestamp_path import resolve_timestamp_path
+
+
+def load_config(
+    root_dir: str | Path,
+    config_filepath: str | Path | None = None,
+    run_id: str | None = None,
+) -> GraphRagConfig:
+    """Load configuration from a file or create a default configuration.
+
+    If a config file is not found the default configuration is created.
+
+    Parameters
+    ----------
+    root_dir : str | Path
+        The root directory of the project. Will search for the config file in this directory.
+    config_filepath : str | Path | None
+        The path to the config file.
+        If None, searches for config file in root and
+        if not found creates a default configuration.
+    run_id : str | None
+        The run id to use for resolving timestamp_paths.
+    """
+    root = Path(root_dir).resolve()
+
+    # If user specified a config file path then it is required
+    if config_filepath:
+        config_path = (root / config_filepath).resolve()
+        if not config_path.exists():
+            msg = f"Specified Config file not found: {config_path}"
+            raise FileNotFoundError(msg)
+
+    # Else optional resolve the config path from the root directory if it exists
+    config_path = search_for_config_in_root_dir(root)
+    if config_path:
+        config = load_config_from_file(config_path)
+    else:
+        config = create_graphrag_config(root_dir=str(root))
+
+    if run_id:
+        config.storage.base_dir = str(
+            resolve_timestamp_path((root / config.storage.base_dir).resolve(), run_id)
+        )
+        config.reporting.base_dir = str(
+            resolve_timestamp_path((root / config.reporting.base_dir).resolve(), run_id)
+        )
+    else:
+        config.storage.base_dir = str(
+            resolve_timestamp_path((root / config.storage.base_dir).resolve())
+        )
+        config.reporting.base_dir = str(
+            resolve_timestamp_path((root / config.reporting.base_dir).resolve())
+        )
+
+    return config
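Note: the following is not part of the diff — a minimal usage sketch of the new load_config entry point, assuming a project directory "./ragtest" that may or may not contain a settings file (the directory name and run id are hypothetical):

    from graphrag.config import load_config

    # Loads ./ragtest/settings.yaml (or .yml/.json) when present; otherwise
    # builds the default configuration from environment variables. Passing a
    # run_id substitutes it into ${timestamp} tokens in the storage and
    # reporting base_dir paths.
    config = load_config("./ragtest", config_filepath=None, run_id="20240812-120000")
    print(config.storage.base_dir)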
@@ -8,7 +8,6 @@ from pathlib import Path

 from .enums import ReportingType
 from .models.graph_rag_config import GraphRagConfig
-from .resolve_timestamp_path import resolve_timestamp_path


 def enable_logging(log_filepath: str | Path, verbose: bool = False) -> None:
@@ -35,7 +34,7 @@ def enable_logging(log_filepath: str | Path, verbose: bool = False) -> None:


 def enable_logging_with_config(
-    config: GraphRagConfig, timestamp_value: str, verbose: bool = False
+    config: GraphRagConfig, verbose: bool = False
 ) -> tuple[bool, str]:
     """Enable logging to a file based on the config.

@@ -56,10 +55,7 @@ def enable_logging_with_config(
         (True, str) if logging was enabled.
     """
     if config.reporting.type == ReportingType.file:
-        log_path = resolve_timestamp_path(
-            Path(config.root_dir) / config.reporting.base_dir / "indexing-engine.log",
-            timestamp_value,
-        )
+        log_path = Path(config.reporting.base_dir) / "indexing-engine.log"
         enable_logging(log_path, verbose)
         return (True, str(log_path))
     return (False, "")
@@ -63,11 +63,6 @@ if __name__ == "__main__":
         help="Create an initial configuration in the given path.",
         action="store_true",
     )
-    parser.add_argument(
-        "--overlay-defaults",
-        help="Overlay default configuration values on a provided configuration file (--config).",
-        action="store_true",
-    )
     parser.add_argument(
         "--skip-validations",
         help="Skip any preflight validation. Useful when running no LLM steps.",
@@ -83,21 +78,17 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()

-    if args.overlay_defaults and not args.config:
-        parser.error("--overlay-defaults requires --config")
-
     index_cli(
-        root=args.root,
+        root_dir=args.root,
         verbose=args.verbose or False,
         resume=args.resume,
         update_index_id=args.update_index,
         memprofile=args.memprofile or False,
         nocache=args.nocache or False,
         reporter=args.reporter,
-        config=args.config,
+        config_filepath=args.config,
         emit=args.emit,
         dryrun=args.dryrun or False,
         init=args.init or False,
-        overlay_defaults=args.overlay_defaults or False,
         skip_validations=args.skip_validations or False,
     )
@@ -8,9 +8,9 @@ WARNING: This API is under development and may undergo changes in future release
 Backwards compatibility is not guaranteed at this time.
 """

-from graphrag.config.enums import CacheType
-from graphrag.config.models.graph_rag_config import GraphRagConfig
-from graphrag.config.resolve_timestamp_path import resolve_timestamp_path
+from pathlib import Path
+from graphrag.config import CacheType, GraphRagConfig

 from .cache.noop_pipeline_cache import NoopPipelineCache
 from .create_pipeline_config import create_pipeline_config
@@ -50,11 +50,7 @@ async def build_index(
     list[PipelineRunResult]
         The list of pipeline run results
     """
-    try:
-        resolve_timestamp_path(config.storage.base_dir, run_id)
-        resume = True
-    except ValueError as _:
-        resume = False
+    resume = Path(config.storage.base_dir).exists()
     pipeline_config = create_pipeline_config(config)
     pipeline_cache = (
         NoopPipelineCache() if config.cache.type == CacheType.none is None else None
@@ -11,13 +11,7 @@ import time
 import warnings
 from pathlib import Path

-from graphrag.config import create_graphrag_config
-from graphrag.config.config_file_loader import (
-    load_config_from_file,
-    resolve_config_path_with_root,
-)
-from graphrag.config.enums import CacheType
-from graphrag.config.logging import enable_logging_with_config
+from graphrag.config import CacheType, enable_logging_with_config, load_config

 from .api import build_index, update_index
 from .graph.extractors.claims.prompts import CLAIM_EXTRACTION_PROMPT
@@ -103,7 +97,7 @@ def _register_signal_handlers(reporter: ProgressReporter):


 def index_cli(
-    root: str,
+    root_dir: str,
     init: bool,
     verbose: bool,
     resume: str | None,
@@ -111,10 +105,9 @@ def index_cli(
     memprofile: bool,
     nocache: bool,
     reporter: str | None,
-    config: str | None,
+    config_filepath: str | None,
     emit: str | None,
     dryrun: bool,
-    overlay_defaults: bool,
     skip_validations: bool,
 ):
     """Run the pipeline with the given config."""
@@ -123,41 +116,30 @@ def index_cli(
     run_id = resume or time.strftime("%Y%m%d-%H%M%S")

     if init:
-        _initialize_project_at(root, progress_reporter)
+        _initialize_project_at(root_dir, progress_reporter)
         sys.exit(0)

-    if overlay_defaults or config:
-        config_path = (
-            Path(root) / config if config else resolve_config_path_with_root(root)
-        )
-        default_config = load_config_from_file(config_path)
-    else:
-        try:
-            config_path = resolve_config_path_with_root(root)
-            default_config = load_config_from_file(config_path)
-        except FileNotFoundError:
-            default_config = create_graphrag_config(root_dir=root)
+    root = Path(root_dir).resolve()
+    config = load_config(root, config_filepath, run_id)

     if nocache:
-        default_config.cache.type = CacheType.none
+        config.cache.type = CacheType.none

-    enabled_logging, log_path = enable_logging_with_config(
-        default_config, run_id, verbose
-    )
+    enabled_logging, log_path = enable_logging_with_config(config, verbose)
     if enabled_logging:
         info(f"Logging enabled at {log_path}", True)
     else:
         info(
-            f"Logging not enabled for config {_redact(default_config.model_dump())}",
+            f"Logging not enabled for config {_redact(config.model_dump())}",
             True,
         )

     if skip_validations:
-        validate_config_names(progress_reporter, default_config)
+        validate_config_names(progress_reporter, config)

     info(f"Starting pipeline run for: {run_id}, {dryrun=}", verbose)
     info(
-        f"Using default configuration: {_redact(default_config.model_dump())}",
+        f"Using default configuration: {_redact(config.model_dump())}",
         verbose,
     )

@@ -169,20 +151,9 @@ def index_cli(

     _register_signal_handlers(progress_reporter)

-    if update_index_id:
-        outputs = asyncio.run(
-            update_index(
-                default_config,
-                memprofile,
-                update_index_id,
-                progress_reporter,
-                pipeline_emit,
-            )
-        )
-    else:
     outputs = asyncio.run(
         build_index(
-            default_config,
+            config,
             run_id,
             memprofile,
             progress_reporter,
@@ -5,11 +5,11 @@

 from pathlib import Path

+from graphrag.config import load_config
 from graphrag.index.progress import PrintProgressReporter
 from graphrag.prompt_tune.generator import MAX_TOKEN_COUNT
 from graphrag.prompt_tune.loader import (
     MIN_CHUNK_SIZE,
-    read_config_parameters,
 )

 from . import api
@@ -53,11 +53,12 @@ async def prompt_tune(
     - min_examples_required: The minimum number of examples required for entity extraction prompts.
     """
     reporter = PrintProgressReporter("")
-    graph_config = read_config_parameters(root, reporter, config)
+    root_path = Path(root).resolve()
+    graph_config = load_config(root_path, config)

     prompts = await api.generate_indexing_prompts(
         config=graph_config,
-        root=root,
+        root=str(root_path),
         chunk_size=chunk_size,
         limit=limit,
         selection_method=selection_method,
@@ -70,7 +71,7 @@ async def prompt_tune(
         k=k,
     )

-    output_path = Path(output)
+    output_path = (root_path / output).resolve()
     if output_path:
         reporter.info(f"Writing prompts to {output_path}")
         output_path.mkdir(parents=True, exist_ok=True)
@@ -3,12 +3,10 @@

 """Fine-tuning config and data loader module."""

-from .config import read_config_parameters
 from .input import MIN_CHUNK_OVERLAP, MIN_CHUNK_SIZE, load_docs_in_chunks

 __all__ = [
     "MIN_CHUNK_OVERLAP",
     "MIN_CHUNK_SIZE",
     "load_docs_in_chunks",
-    "read_config_parameters",
 ]
@@ -1,61 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-
-"""Config loading, parsing and handling module."""
-
-from pathlib import Path
-
-from graphrag.config import create_graphrag_config
-from graphrag.index.progress.types import ProgressReporter
-
-
-def read_config_parameters(
-    root: str, reporter: ProgressReporter, config: str | None = None
-):
-    """Read the configuration parameters from the settings file or environment variables.
-
-    Parameters
-    ----------
-    - root: The root directory where the parameters are.
-    - reporter: The progress reporter.
-    - config: The path to the settings file.
-    """
-    _root = Path(root)
-    settings_yaml = (
-        Path(config)
-        if config and Path(config).suffix in [".yaml", ".yml"]
-        else _root / "settings.yaml"
-    )
-    if not settings_yaml.exists():
-        settings_yaml = _root / "settings.yml"
-    if settings_yaml.exists():
-        reporter.info(f"Reading settings from {settings_yaml}")
-        with settings_yaml.open("rb") as file:
-            import yaml
-
-            data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict"))
-            return create_graphrag_config(data, root)
-
-    settings_json = (
-        Path(config)
-        if config and Path(config).suffix == ".json"
-        else _root / "settings.json"
-    )
-    if settings_yaml.exists():
-        reporter.info(f"Reading settings from {settings_yaml}")
-        with settings_yaml.open("rb") as file:
-            import yaml
-
-            data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict"))
-            return create_graphrag_config(data, root)
-
-    if settings_json.exists():
-        reporter.info(f"Reading settings from {settings_json}")
-        with settings_json.open("rb") as file:
-            import json
-
-            data = json.loads(file.read().decode(encoding="utf-8", errors="strict"))
-            return create_graphrag_config(data, root)
-
-    reporter.info("Reading settings from environment variables")
-    return create_graphrag_config(root_dir=root)
@@ -24,8 +24,7 @@ from typing import Any
 import pandas as pd
 from pydantic import validate_call

-from graphrag.config.models.graph_rag_config import GraphRagConfig
-from graphrag.config.resolve_timestamp_path import resolve_timestamp_path
+from graphrag.config import GraphRagConfig
 from graphrag.index.progress.types import PrintProgressReporter
 from graphrag.model.entity import Entity
 from graphrag.query.structured_search.base import SearchResult  # noqa: TCH001
@@ -149,7 +148,6 @@ async def global_search_streaming(

 @validate_call(config={"arbitrary_types_allowed": True})
 async def local_search(
-    root_dir: str | None,
     config: GraphRagConfig,
     nodes: pd.DataFrame,
     entities: pd.DataFrame,
@@ -196,9 +194,8 @@ async def local_search(

     _entities = read_indexer_entities(nodes, entities, community_level)

-    base_dir = Path(str(root_dir)) / config.storage.base_dir
-    resolved_base_dir = resolve_timestamp_path(base_dir)
-    lancedb_dir = resolved_base_dir / "lancedb"
+    lancedb_dir = Path(config.storage.base_dir) / "lancedb"
     vector_store_args.update({"db_uri": str(lancedb_dir)})
     description_embedding_store = _get_embedding_description_store(
         entities=_entities,
@@ -227,7 +224,6 @@ async def local_search(

 @validate_call(config={"arbitrary_types_allowed": True})
 async def local_search_streaming(
-    root_dir: str | None,
     config: GraphRagConfig,
     nodes: pd.DataFrame,
     entities: pd.DataFrame,
@@ -271,9 +267,8 @@ async def local_search_streaming(

     _entities = read_indexer_entities(nodes, entities, community_level)

-    base_dir = Path(str(root_dir)) / config.storage.base_dir
-    resolved_base_dir = resolve_timestamp_path(base_dir)
-    lancedb_dir = resolved_base_dir / "lancedb"
+    lancedb_dir = lancedb_dir = Path(config.storage.base_dir) / "lancedb"
     vector_store_args.update({"db_uri": str(lancedb_dir)})
     description_embedding_store = _get_embedding_description_store(
         entities=_entities,
|
|||||||
"""Command line interface for the query module."""
|
"""Command line interface for the query module."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import re
|
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
from graphrag.config import (
|
from graphrag.config import load_config, resolve_timestamp_path
|
||||||
GraphRagConfig,
|
|
||||||
create_graphrag_config,
|
|
||||||
)
|
|
||||||
from graphrag.config.resolve_timestamp_path import resolve_timestamp_path
|
|
||||||
from graphrag.index.progress import PrintProgressReporter
|
from graphrag.index.progress import PrintProgressReporter
|
||||||
|
|
||||||
from . import api
|
from . import api
|
||||||
@ -25,7 +20,7 @@ reporter = PrintProgressReporter("")
|
|||||||
def run_global_search(
|
def run_global_search(
|
||||||
config_filepath: str | None,
|
config_filepath: str | None,
|
||||||
data_dir: str | None,
|
data_dir: str | None,
|
||||||
root_dir: str | None,
|
root_dir: str,
|
||||||
community_level: int,
|
community_level: int,
|
||||||
response_type: str,
|
response_type: str,
|
||||||
streaming: bool,
|
streaming: bool,
|
||||||
@ -35,10 +30,15 @@ def run_global_search(
|
|||||||
|
|
||||||
Loads index files required for global search and calls the Query API.
|
Loads index files required for global search and calls the Query API.
|
||||||
"""
|
"""
|
||||||
data_dir, root_dir, config = _configure_paths_and_settings(
|
root = Path(root_dir).resolve()
|
||||||
data_dir, root_dir, config_filepath
|
config = load_config(root, config_filepath)
|
||||||
|
|
||||||
|
if data_dir:
|
||||||
|
config.storage.base_dir = str(
|
||||||
|
resolve_timestamp_path((root / data_dir).resolve())
|
||||||
)
|
)
|
||||||
data_path = Path(data_dir)
|
|
||||||
|
data_path = Path(config.storage.base_dir).resolve()
|
||||||
|
|
||||||
final_nodes: pd.DataFrame = pd.read_parquet(
|
final_nodes: pd.DataFrame = pd.read_parquet(
|
||||||
data_path / "create_final_nodes.parquet"
|
data_path / "create_final_nodes.parquet"
|
||||||
@ -98,7 +98,7 @@ def run_global_search(
|
|||||||
def run_local_search(
|
def run_local_search(
|
||||||
config_filepath: str | None,
|
config_filepath: str | None,
|
||||||
data_dir: str | None,
|
data_dir: str | None,
|
||||||
root_dir: str | None,
|
root_dir: str,
|
||||||
community_level: int,
|
community_level: int,
|
||||||
response_type: str,
|
response_type: str,
|
||||||
streaming: bool,
|
streaming: bool,
|
||||||
@ -108,10 +108,15 @@ def run_local_search(
|
|||||||
|
|
||||||
Loads index files required for local search and calls the Query API.
|
Loads index files required for local search and calls the Query API.
|
||||||
"""
|
"""
|
||||||
data_dir, root_dir, config = _configure_paths_and_settings(
|
root = Path(root_dir).resolve()
|
||||||
data_dir, root_dir, config_filepath
|
config = load_config(root, config_filepath)
|
||||||
|
|
||||||
|
if data_dir:
|
||||||
|
config.storage.base_dir = str(
|
||||||
|
resolve_timestamp_path((root / data_dir).resolve())
|
||||||
)
|
)
|
||||||
data_path = Path(data_dir)
|
|
||||||
|
data_path = Path(config.storage.base_dir).resolve()
|
||||||
|
|
||||||
final_nodes = pd.read_parquet(data_path / "create_final_nodes.parquet")
|
final_nodes = pd.read_parquet(data_path / "create_final_nodes.parquet")
|
||||||
final_community_reports = pd.read_parquet(
|
final_community_reports = pd.read_parquet(
|
||||||
@ -137,7 +142,6 @@ def run_local_search(
|
|||||||
context_data = None
|
context_data = None
|
||||||
get_context_data = True
|
get_context_data = True
|
||||||
async for stream_chunk in api.local_search_streaming(
|
async for stream_chunk in api.local_search_streaming(
|
||||||
root_dir=root_dir,
|
|
||||||
config=config,
|
config=config,
|
||||||
nodes=final_nodes,
|
nodes=final_nodes,
|
||||||
entities=final_entities,
|
entities=final_entities,
|
||||||
@ -163,7 +167,6 @@ def run_local_search(
|
|||||||
# not streaming
|
# not streaming
|
||||||
response, context_data = asyncio.run(
|
response, context_data = asyncio.run(
|
||||||
api.local_search(
|
api.local_search(
|
||||||
root_dir=root_dir,
|
|
||||||
config=config,
|
config=config,
|
||||||
nodes=final_nodes,
|
nodes=final_nodes,
|
||||||
entities=final_entities,
|
entities=final_entities,
|
||||||
@ -180,77 +183,3 @@ def run_local_search(
|
|||||||
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
|
# NOTE: we return the response and context data here purely as a complete demonstration of the API.
|
||||||
# External users should use the API directly to get the response and context data.
|
# External users should use the API directly to get the response and context data.
|
||||||
return response, context_data
|
return response, context_data
|
||||||
|
|
||||||
|
|
||||||
def _configure_paths_and_settings(
|
|
||||||
data_dir: str | None,
|
|
||||||
root_dir: str | None,
|
|
||||||
config_filepath: str | None,
|
|
||||||
) -> tuple[str, str | None, GraphRagConfig]:
|
|
||||||
config = _create_graphrag_config(root_dir, config_filepath)
|
|
||||||
if data_dir is None and root_dir is None:
|
|
||||||
msg = "Either data_dir or root_dir must be provided."
|
|
||||||
raise ValueError(msg)
|
|
||||||
if data_dir is None:
|
|
||||||
base_dir = Path(str(root_dir)) / config.storage.base_dir
|
|
||||||
data_dir = str(resolve_timestamp_path(base_dir))
|
|
||||||
return data_dir, root_dir, config
|
|
||||||
|
|
||||||
|
|
||||||
def _infer_data_dir(root: str) -> str:
|
|
||||||
output = Path(root) / "output"
|
|
||||||
# use the latest data-run folder
|
|
||||||
if output.exists():
|
|
||||||
expr = re.compile(r"\d{8}-\d{6}")
|
|
||||||
filtered = [f for f in output.iterdir() if f.is_dir() and expr.match(f.name)]
|
|
||||||
folders = sorted(filtered, key=lambda f: f.name, reverse=True)
|
|
||||||
if len(folders) > 0:
|
|
||||||
folder = folders[0]
|
|
||||||
return str((folder / "artifacts").absolute())
|
|
||||||
msg = f"Could not infer data directory from root={root}"
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
|
|
||||||
def _create_graphrag_config(
|
|
||||||
root: str | None,
|
|
||||||
config_filepath: str | None,
|
|
||||||
) -> GraphRagConfig:
|
|
||||||
"""Create a GraphRag configuration."""
|
|
||||||
return _read_config_parameters(root or "./", config_filepath)
|
|
||||||
|
|
||||||
|
|
||||||
def _read_config_parameters(root: str, config: str | None):
|
|
||||||
_root = Path(root)
|
|
||||||
settings_yaml = (
|
|
||||||
Path(config)
|
|
||||||
if config and Path(config).suffix in [".yaml", ".yml"]
|
|
||||||
else _root / "settings.yaml"
|
|
||||||
)
|
|
||||||
if not settings_yaml.exists():
|
|
||||||
settings_yaml = _root / "settings.yml"
|
|
||||||
|
|
||||||
if settings_yaml.exists():
|
|
||||||
reporter.info(f"Reading settings from {settings_yaml}")
|
|
||||||
with settings_yaml.open(
|
|
||||||
"rb",
|
|
||||||
) as file:
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
data = yaml.safe_load(file.read().decode(encoding="utf-8", errors="strict"))
|
|
||||||
return create_graphrag_config(data, root)
|
|
||||||
|
|
||||||
settings_json = (
|
|
||||||
Path(config)
|
|
||||||
if config and Path(config).suffix == ".json"
|
|
||||||
else _root / "settings.json"
|
|
||||||
)
|
|
||||||
if settings_json.exists():
|
|
||||||
reporter.info(f"Reading settings from {settings_json}")
|
|
||||||
with settings_json.open("rb") as file:
|
|
||||||
import json
|
|
||||||
|
|
||||||
data = json.loads(file.read().decode(encoding="utf-8", errors="strict"))
|
|
||||||
return create_graphrag_config(data, root)
|
|
||||||
|
|
||||||
reporter.info("Reading settings from environment variables")
|
|
||||||
return create_graphrag_config(root_dir=root)
|
|
||||||
|
|||||||
@@ -123,6 +123,7 @@ _convert_global_search_nb = 'jupyter nbconvert --output-dir=docsite/posts/query/
 _semversioner_release = "semversioner release"
 _semversioner_changelog = "semversioner changelog > CHANGELOG.md"
 _semversioner_update_toml_version = "update-toml update --path tool.poetry.version --value $(poetry run semversioner current-version)"
+semversioner_add = "semversioner add-change"
 coverage_report = 'coverage report --omit "**/tests/**" --show-missing'
 check_format = 'ruff format . --check --preview'
 fix = "ruff --preview check --fix ."
tests/unit/config/test_resolve_timestamp_path.py (new file, 33 lines)
@@ -0,0 +1,33 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+from pathlib import Path
+
+from graphrag.config.resolve_timestamp_path import resolve_timestamp_path
+
+
+def test_resolve_timestamp_path_no_timestamp_with_run_id():
+    path = Path("path/to/data")
+    result = resolve_timestamp_path(path, "20240812-121000")
+    assert result == path
+
+
+def test_resolve_timestamp_path_no_timestamp_without_run_id():
+    path = Path("path/to/data")
+    result = resolve_timestamp_path(path)
+    assert result == path
+
+
+def test_resolve_timestamp_path_with_timestamp_and_run_id():
+    path = Path("some/path/${timestamp}/data")
+    expected = Path("some/path/20240812/data")
+    result = resolve_timestamp_path(path, "20240812")
+    assert result == expected
+
+
+def test_resolve_timestamp_path_with_timestamp_and_inferred_directory():
+    cwd = Path(__file__).parent
+    path = cwd / "fixtures/timestamp_dirs/${timestamp}/data"
+    expected = cwd / "fixtures/timestamp_dirs/20240812-120000/data"
+    result = resolve_timestamp_path(path)
+    assert result == expected
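Note: not part of the diff — a rough re-implementation sketched from what these tests pin down, to make the contract concrete. resolve_timestamp_path appears to return the path unchanged when no "${timestamp}" segment is present, substitute run_id for that segment when one is given, and otherwise infer the latest run directory on disk (raising ValueError when none exists, which is the error the pre-change build_index caught). The function body below is hypothetical; only the call signature and the assertions come from the diff:

    from pathlib import Path

    def resolve_timestamp_path_sketch(path: Path, run_id: str | None = None) -> Path:
        parts = path.parts
        if "${timestamp}" not in parts:
            return path  # nothing to resolve
        idx = parts.index("${timestamp}")
        if run_id is None:
            # Infer the latest run: lexicographically greatest child directory,
            # which sorts correctly for YYYYMMDD-HHMMSS names. max() raises
            # ValueError on an empty sequence when no run directory exists.
            parent = Path(*parts[:idx])
            run_id = max(d.name for d in parent.iterdir() if d.is_dir())
        return Path(*parts[:idx], run_id, *parts[idx + 1 :])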
@@ -1,32 +0,0 @@
-# Copyright (c) 2024 Microsoft Corporation.
-# Licensed under the MIT License
-from pathlib import Path
-
-import pytest
-
-from graphrag.query.cli import _infer_data_dir
-
-
-def test_infer_data_dir():
-    root = "./tests/unit/query/data/defaults"
-    result = Path(_infer_data_dir(root))
-    assert result.parts[-2] == "20240812-121000"
-
-
-def test_infer_data_dir_ignores_hidden_files():
-    """A hidden file, starting with '.', will naturally be selected as latest data directory."""
-    root = "./tests/unit/query/data/hidden"
-    result = Path(_infer_data_dir(root))
-    assert result.parts[-2] == "20240812-121000"
-
-
-def test_infer_data_dir_ignores_non_numeric():
-    root = "./tests/unit/query/data/non-numeric"
-    result = Path(_infer_data_dir(root))
-    assert result.parts[-2] == "20240812-121000"
-
-
-def test_infer_data_dir_throws_on_no_match():
-    root = "./tests/unit/query/data/empty"
-    with pytest.raises(ValueError):  # noqa PT011 (this is what is actually thrown...)
-        _infer_data_dir(root)