# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

import asyncio
import json
import logging
import os
import shutil
import subprocess
from collections.abc import Callable
from functools import wraps
from pathlib import Path
from typing import Any, ClassVar
from unittest import mock

import pandas as pd
import pytest

from graphrag.query.context_builder.community_context import (
    NO_COMMUNITY_RECORDS_WARNING,
)
from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage

log = logging.getLogger(__name__)

debug = os.environ.get("DEBUG") is not None
gh_pages = os.environ.get("GH_PAGES") is not None
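# DEBUG keeps run artifacts on disk (cleanup is skipped below) and adds
# --verbose to the indexer CLI; GH_PAGES restricts the run to the min-csv
# fixture so the docsite build only hydrates the artifacts it needs.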

# cspell:disable-next-line well-known-key
WELL_KNOWN_AZURITE_CONNECTION_STRING = "DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://127.0.0.1:10000/devstoreaccount1"
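# This is Azurite's publicly documented development-storage account key
# (devstoreaccount1), so committing it is safe: it only grants access to a
# local emulator.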

KNOWN_WARNINGS = [NO_COMMUNITY_RECORDS_WARNING]


def _load_fixtures():
    """Load all fixture configs from the tests/fixtures folder."""
    params = []
    fixtures_path = Path("./tests/fixtures/")
    # use the min-csv smoke test to hydrate the docsite parquet artifacts (see gh-pages.yml)
    subfolders = ["min-csv"] if gh_pages else sorted(os.listdir(fixtures_path))

    for subfolder in subfolders:
        if not os.path.isdir(fixtures_path / subfolder):
            continue

        config_file = fixtures_path / subfolder / "config.json"
        params.append((subfolder, json.loads(config_file.read_bytes().decode("utf-8"))))

    # Drop the first fixture (in sorted order) to disable the azure blob connection test
    return params[1:]
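

# For reference, a fixture's config.json carries one entry per argument of
# test_fixture, plus an optional "slow" flag. A minimal sketch (names and
# values are illustrative, not copied from a real fixture):
#
#   {
#       "input_path": "./tests/fixtures/min-csv",
#       "input_file_type": "text",
#       "slow": false,
#       "workflow_config": {
#           "create_base_text_units": {
#               "max_runtime": 300,
#               "row_range": [1, 1000],
#               "expected_artifacts": ["create_base_text_units.parquet"],
#               "nan_allowed_columns": ["metadata"]
#           }
#       },
#       "query_config": [
#           {"method": "global", "query": "What are the top themes?"}
#       ]
#   }
#
# workflow_config may also carry top-level "skip", "skip_assert", and "azure"
# ({"input_container": ..., "input_base_dir": ...}) entries, which test_fixture
# reads before the per-workflow checks.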


def pytest_generate_tests(metafunc):
    """Generate tests for all test functions in this module."""
    run_slow = metafunc.config.getoption("run_slow")
    configs = metafunc.cls.params[metafunc.function.__name__]

    if not run_slow:
        # Only run tests that are not marked as slow
        configs = [config for config in configs if not config[1].get("slow", False)]

    funcarglist = [params[1] for params in configs]
    id_list = [params[0] for params in configs]

    argnames = sorted(arg for arg in funcarglist[0] if arg != "slow")
    metafunc.parametrize(
        argnames,
        [[funcargs[name] for name in argnames] for funcargs in funcarglist],
        ids=id_list,
    )
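

# pytest_generate_tests is a standard pytest hook: it runs once per test
# function at collection time. Here it looks up the class-level `params`
# mapping by function name and expands each fixture folder into one
# parametrized test case, using the folder name as the test id. The
# "run_slow" option is assumed to be registered in a conftest.py.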


def cleanup(skip: bool = False):
    """Decorator to clean up the output and cache folders after each test."""

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except AssertionError:
                raise
            finally:
                if not skip:
                    root = Path(kwargs["input_path"])
                    shutil.rmtree(root / "output", ignore_errors=True)
                    shutil.rmtree(root / "cache", ignore_errors=True)

        return wrapper

    return decorator
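

# Usage sketch: the decorated test must take the fixture root as an
# `input_path` keyword argument, since wrapper() reads kwargs["input_path"]:
#
#   @cleanup(skip=debug)
#   def test_something(self, input_path: str, ...): ...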


async def prepare_azurite_data(input_path: str, azure: dict) -> Callable[[], None]:
    """Prepare the data for the Azurite tests."""
    input_container = azure["input_container"]
    input_base_dir = azure.get("input_base_dir")

    root = Path(input_path)
    input_storage = BlobPipelineStorage(
        connection_string=WELL_KNOWN_AZURITE_CONNECTION_STRING,
        container_name=input_container,
    )
    # Bounce the container if it exists to clear out old run data
    input_storage._delete_container()  # noqa: SLF001
    input_storage._create_container()  # noqa: SLF001

    # Upload data files
    txt_files = list((root / "input").glob("*.txt"))
    csv_files = list((root / "input").glob("*.csv"))
    data_files = txt_files + csv_files
    for data_file in data_files:
        text = data_file.read_bytes().decode("utf-8")
        file_path = (
            str(Path(input_base_dir) / data_file.name)
            if input_base_dir
            else data_file.name
        )
        await input_storage.set(file_path, text, encoding="utf-8")

    return lambda: input_storage._delete_container()  # noqa: SLF001
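

# The return value is a disposer: the caller runs the indexer against the
# seeded container, then invokes it to tear the container back down, e.g.
#
#   dispose = asyncio.run(prepare_azurite_data(input_path, azure))
#   ...  # run the indexer
#   dispose()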


class TestIndexer:
    params: ClassVar[dict[str, list[tuple[str, dict[str, Any]]]]] = {
        "test_fixture": _load_fixtures()
    }

    def __run_indexer(
        self,
        root: Path,
        input_file_type: str,
    ):
        command = [
            "poetry",
            "run",
            "poe",
            "index",
            "--verbose" if debug else None,
            "--root",
            root.resolve().as_posix(),
            "--logger",
            "print",
        ]
        # Drop the None placeholder when --verbose is not requested
        command = [arg for arg in command if arg]
        log.info("running command: %s", " ".join(command))
        completion = subprocess.run(
            command, env={**os.environ, "GRAPHRAG_INPUT_FILE_TYPE": input_file_type}
        )
        assert completion.returncode == 0, (
            f"Indexer failed with return code: {completion.returncode}"
        )
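
    # The assembled invocation is effectively:
    #
    #   GRAPHRAG_INPUT_FILE_TYPE=<type> poetry run poe index --root <fixture> --logger print
    #
    # with --verbose added when DEBUG is set.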

    def __assert_indexer_outputs(
        self, root: Path, workflow_config: dict[str, dict[str, Any]]
    ):
        output_path = root / "output"

        assert output_path.exists(), "output folder does not exist"

        # Check stats for all workflows
        stats = json.loads((output_path / "stats.json").read_bytes().decode("utf-8"))

        # Check that all expected workflows ran, and no unexpected ones did
        expected_workflows = set(workflow_config.keys())
        workflows = set(stats["workflows"].keys())
        assert workflows == expected_workflows, (
            f"Workflows missing from stats.json: {expected_workflows - workflows}. Unexpected workflows in stats.json: {workflows - expected_workflows}"
        )

        # Per-workflow checks: runtime (optional) and expected artifacts
        for workflow, config in workflow_config.items():
            workflow_artifacts = config.get("expected_artifacts", [])
            # Check max runtime
            max_runtime = config.get("max_runtime", None)
            if max_runtime:
                assert stats["workflows"][workflow]["overall"] <= max_runtime, (
                    f"Expected max runtime of {max_runtime}, found: {stats['workflows'][workflow]['overall']} for workflow: {workflow}"
                )
            # Check expected artifacts
            for artifact in workflow_artifacts:
                if artifact.endswith(".parquet"):
                    output_df = pd.read_parquet(output_path / artifact)

                    # Check the number of rows is within the configured range
                    assert (
                        config["row_range"][0]
                        <= len(output_df)
                        <= config["row_range"][1]
                    ), (
                        f"Expected between {config['row_range'][0]} and {config['row_range'][1]}, found: {len(output_df)} for file: {artifact}"
                    )

                    # Flag any NaNs outside the columns where they are allowed
                    nan_df = output_df.loc[
                        :,
                        ~output_df.columns.isin(config.get("nan_allowed_columns", [])),
                    ]
                    nan_df = nan_df[nan_df.isna().any(axis=1)]
                    assert len(nan_df) == 0, (
                        f"Found {len(nan_df)} rows with NaN values for file: {artifact} on columns: {nan_df.columns[nan_df.isna().any()].tolist()}"
                    )
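
    # The NaN check above works in two steps: first restrict to the columns
    # where NaN is not allowed (everything outside nan_allowed_columns), then
    # keep only rows that still contain a NaN; anything left fails the assert.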

    def __run_query(self, root: Path, query_config: dict[str, str]):
        command = [
            "poetry",
            "run",
            "poe",
            "query",
            "--root",
            root.resolve().as_posix(),
            "--method",
            query_config["method"],
            "--community-level",
            str(query_config.get("community_level", 2)),
            "--query",
            query_config["query"],
        ]

        log.info("running command: %s", " ".join(command))
        return subprocess.run(command, capture_output=True, text=True)
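
    # For a query_config entry like {"method": "global", "query": "What are
    # the top themes?"} (values illustrative), this runs:
    #
    #   poetry run poe query --root <fixture> --method global --community-level 2 --query "What are the top themes?"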

    @cleanup(skip=debug)
    @mock.patch.dict(
        os.environ,
        {
            **os.environ,
            "BLOB_STORAGE_CONNECTION_STRING": os.getenv(
                "GRAPHRAG_CACHE_CONNECTION_STRING", WELL_KNOWN_AZURITE_CONNECTION_STRING
            ),
            "LOCAL_BLOB_STORAGE_CONNECTION_STRING": WELL_KNOWN_AZURITE_CONNECTION_STRING,
            "GRAPHRAG_CHUNK_SIZE": "1200",
            "GRAPHRAG_CHUNK_OVERLAP": "0",
            "AZURE_AI_SEARCH_URL_ENDPOINT": os.getenv("AZURE_AI_SEARCH_URL_ENDPOINT"),
            "AZURE_AI_SEARCH_API_KEY": os.getenv("AZURE_AI_SEARCH_API_KEY"),
        },
        clear=True,
    )
    @pytest.mark.timeout(800)
    def test_fixture(
        self,
        input_path: str,
        input_file_type: str,
        workflow_config: dict[str, dict[str, Any]],
        query_config: list[dict[str, str]],
    ):
        if workflow_config.get("skip"):
            print(f"skipping smoke test {input_path}")
            return

        azure = workflow_config.get("azure")
        root = Path(input_path)
        dispose = None
        if azure is not None:
            dispose = asyncio.run(prepare_azurite_data(input_path, azure))

        print("running indexer")
        self.__run_indexer(root, input_file_type)
        print("indexer complete")

        if dispose is not None:
            dispose()

        if not workflow_config.get("skip_assert"):
            print("performing dataset assertions")
            self.__assert_indexer_outputs(root, workflow_config)

        print("running queries")
        for query in query_config:
            result = self.__run_query(root, query)
            print(f"Query: {query}\nResponse: {result.stdout}")

            # Check that the query executed and produced output
            assert result.returncode == 0, "Query failed"
            assert result.stdout is not None, "Query returned no output"
            assert len(result.stdout) > 0, "Query returned empty output"
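

# A minimal sketch of running this suite locally, assuming an Azurite emulator
# listening on 127.0.0.1:10000 and a conftest.py that registers the run_slow
# option (the path and -k filter are illustrative):
#
#   poetry run pytest tests/smoke -k min-csv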