Attempt to update smoke tests for multi-index query.

This commit is contained in:
Doug Orbaker 2025-06-04 14:05:30 +00:00
parent b409cef0ce
commit a63eed0b76
3 changed files with 279 additions and 0 deletions

tests/fixtures/text/settings_input2.yml (new file, vendored, 61 additions)

@@ -0,0 +1,61 @@
models:
  default_chat_model:
    azure_auth_type: api_key
    type: ${GRAPHRAG_LLM_TYPE}
    api_key: ${GRAPHRAG_API_KEY}
    api_base: ${GRAPHRAG_API_BASE}
    api_version: ${GRAPHRAG_API_VERSION}
    deployment_name: ${GRAPHRAG_LLM_DEPLOYMENT_NAME}
    model: ${GRAPHRAG_LLM_MODEL}
    tokens_per_minute: ${GRAPHRAG_LLM_TPM}
    requests_per_minute: ${GRAPHRAG_LLM_RPM}
    model_supports_json: true
    concurrent_requests: 50
    async_mode: threaded
  default_embedding_model:
    azure_auth_type: api_key
    type: ${GRAPHRAG_EMBEDDING_TYPE}
    api_key: ${GRAPHRAG_API_KEY}
    api_base: ${GRAPHRAG_API_BASE}
    api_version: ${GRAPHRAG_API_VERSION}
    deployment_name: ${GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME}
    model: ${GRAPHRAG_EMBEDDING_MODEL}
    tokens_per_minute: null
    requests_per_minute: null
    concurrent_requests: 50
    async_mode: threaded

vector_store:
  default_vector_store2:
    type: "azure_ai_search"
    url: ${AZURE_AI_SEARCH_URL_ENDPOINT}
    api_key: ${AZURE_AI_SEARCH_API_KEY}
    container_name: "simple_text_ci2"

input:
  type: file # or blob
  file_type: text # [csv, text, json]
  base_dir: "./tests/fixtures/text/input2"
  file_encoding: utf-8
  file_pattern: ".*\\.txt$$"

output:
  type: file # [file, blob, cosmosdb]
  base_dir: "./tests/fixtures/text/output2"

extract_claims:
  enabled: true

community_reports:
  prompt: "prompts/community_report.txt"
  max_length: 2000
  max_input_length: 8000

snapshots:
  embeddings: True

drift_search:
  n_depth: 1
  drift_k_followups: 3
  primer_folds: 3
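Note: the ${...} values above are placeholders that GraphRAG resolves from environment variables when the settings file is loaded, and $$ stands for a literal $ (which is why file_pattern ends in $$ rather than $). The snippet below is only an illustrative sketch of that substitution rule, assuming string.Template-style semantics; it is not the library's actual config loader.

    # Illustrative sketch (not GraphRAG's loader): resolve ${VAR} placeholders
    # from the environment; "$$" collapses to a literal "$".
    import os
    from string import Template

    def resolve_env_placeholders(raw_yaml_text: str) -> str:
        # safe_substitute leaves unknown variables untouched instead of raising
        return Template(raw_yaml_text).safe_substitute(os.environ)

    example = "api_key: ${GRAPHRAG_API_KEY}\nfile_pattern: .*\\.txt$$\n"
    print(resolve_env_placeholders(example))
    # api_key is filled in if GRAPHRAG_API_KEY is set; file_pattern becomes .*\.txt$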

tests/fixtures/text/settings_miq.yml (new file, vendored, 62 additions)

@@ -0,0 +1,62 @@
models:
  default_chat_model:
    azure_auth_type: api_key
    type: ${GRAPHRAG_LLM_TYPE}
    api_key: ${GRAPHRAG_API_KEY}
    api_base: ${GRAPHRAG_API_BASE}
    api_version: ${GRAPHRAG_API_VERSION}
    deployment_name: ${GRAPHRAG_LLM_DEPLOYMENT_NAME}
    model: ${GRAPHRAG_LLM_MODEL}
    tokens_per_minute: ${GRAPHRAG_LLM_TPM}
    requests_per_minute: ${GRAPHRAG_LLM_RPM}
    model_supports_json: true
    concurrent_requests: 50
    async_mode: threaded
  default_embedding_model:
    azure_auth_type: api_key
    type: ${GRAPHRAG_EMBEDDING_TYPE}
    api_key: ${GRAPHRAG_API_KEY}
    api_base: ${GRAPHRAG_API_BASE}
    api_version: ${GRAPHRAG_API_VERSION}
    deployment_name: ${GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME}
    model: ${GRAPHRAG_EMBEDDING_MODEL}
    tokens_per_minute: null
    requests_per_minute: null
    concurrent_requests: 50
    async_mode: threaded

outputs:
  index1:
    type: file # [file, blob, cosmosdb]
    base_dir: "./tests/fixtures/text/output"
  index2:
    type: file # [file, blob, cosmosdb]
    base_dir: "./tests/fixtures/text/output2"

vector_store:
  index1:
    type: "azure_ai_search"
    url: ${AZURE_AI_SEARCH_URL_ENDPOINT}
    api_key: ${AZURE_AI_SEARCH_API_KEY}
    container_name: "simple_text_ci"
  index2:
    type: "azure_ai_search"
    url: ${AZURE_AI_SEARCH_URL_ENDPOINT}
    api_key: ${AZURE_AI_SEARCH_API_KEY}
    container_name: "simple_text_ci2"

extract_claims:
  enabled: true

community_reports:
  prompt: "prompts/community_report.txt"
  max_length: 2000
  max_input_length: 8000

snapshots:
  embeddings: True

drift_search:
  n_depth: 1
  drift_k_followups: 3
  primer_folds: 3
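In this multi-index settings file, each named entry under outputs has a same-named entry under vector_store (index1, index2), presumably so a multi-index query can pair each index's artifacts with its Azure AI Search container. A small, hypothetical sanity check (not part of this commit) for keeping the two maps in sync could look like:

    # Hypothetical helper (not part of this commit): verify that outputs and
    # vector_store in settings_miq.yml define the same set of index names.
    from pathlib import Path

    import yaml  # PyYAML

    def check_index_keys(settings_path: str) -> None:
        settings = yaml.safe_load(Path(settings_path).read_text(encoding="utf-8"))
        output_keys = set(settings.get("outputs", {}))
        vector_keys = set(settings.get("vector_store", {}))
        assert output_keys == vector_keys, (
            f"outputs/vector_store index mismatch: {output_keys ^ vector_keys}"
        )

    check_index_keys("tests/fixtures/text/settings_miq.yml")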


@@ -83,6 +83,7 @@ def cleanup(skip: bool = False):
                    root = Path(kwargs["input_path"])
                    shutil.rmtree(root / "output", ignore_errors=True)
                    shutil.rmtree(root / "cache", ignore_errors=True)
                    shutil.rmtree(root / "output2", ignore_errors=True)
        return wrapper
@@ -118,6 +119,35 @@ async def prepare_azurite_data(input_path: str, azure: dict) -> Callable[[], None]:
    return lambda: input_storage._delete_container()  # noqa: SLF001


async def prepare_azurite_data2(input_path: str, azure: dict) -> Callable[[], None]:
    """Prepare the input2 data for the Azurite tests."""
    input_container = azure["input_container"]
    input_base_dir = azure.get("input_base_dir")
    root = Path(input_path)
    input_storage = BlobPipelineStorage(
        connection_string=WELL_KNOWN_AZURITE_CONNECTION_STRING,
        container_name=input_container,
    )
    # Bounce the container if it exists to clear out old run data
    input_storage._delete_container()  # noqa: SLF001
    input_storage._create_container()  # noqa: SLF001
    # Upload data files
    txt_files = list((root / "input2").glob("*.txt"))
    csv_files = list((root / "input2").glob("*.csv"))
    data_files = txt_files + csv_files
    for data_file in data_files:
        text = data_file.read_bytes().decode("utf-8")
        file_path = (
            str(Path(input_base_dir) / data_file.name)
            if input_base_dir
            else data_file.name
        )
        await input_storage.set(file_path, text, encoding="utf-8")
    return lambda: input_storage._delete_container()  # noqa: SLF001


class TestIndexer:
    params: ClassVar[dict[str, list[tuple[str, dict[str, Any]]]]] = {
@@ -151,6 +181,35 @@ class TestIndexer:
            f"Indexer failed with return code: {completion.returncode}"
        )

    def __run_indexer2(
        self,
        root: Path,
        input_file_type: str,
    ):
        command = [
            "poetry",
            "run",
            "poe",
            "index",
            "--verbose" if debug else None,
            "--root",
            root.resolve().as_posix(),
            "--logger",
            "print",
            "--method",
            "standard",
            "--config",
            root.resolve().as_posix() + "/settings_input2.yml",
        ]
        command = [arg for arg in command if arg]
        log.info("running command: %s", " ".join(command))
        completion = subprocess.run(
            command, env={**os.environ, "GRAPHRAG_INPUT_FILE_TYPE": input_file_type}
        )
        assert completion.returncode == 0, (
            f"Indexer failed with return code: {completion.returncode}"
        )
    def __assert_indexer_outputs(
        self, root: Path, workflow_config: dict[str, dict[str, Any]]
    ):

@@ -202,6 +261,58 @@ class TestIndexer:
                        f"Found {len(nan_df)} rows with NaN values for file: {artifact} on columns: {nan_df.columns[nan_df.isna().any()].tolist()}"
                    )

    def __assert_indexer_outputs2(
        self, root: Path, workflow_config: dict[str, dict[str, Any]]
    ):
        output_path = root / "output2"
        assert output_path.exists(), "output2 folder does not exist"

        # Check stats for all workflows
        stats = json.loads((output_path / "stats.json").read_bytes().decode("utf-8"))

        # Check all workflows ran
        expected_workflows = set(workflow_config.keys())
        workflows = set(stats["workflows"].keys())
        assert workflows == expected_workflows, (
            f"Workflows missing from stats.json: {expected_workflows - workflows}. Unexpected workflows in stats.json: {workflows - expected_workflows}"
        )

        # [OPTIONAL] Check runtime
        for workflow, config in workflow_config.items():
            # Expected artifacts for this workflow
            workflow_artifacts = config.get("expected_artifacts", [])
            # Check max runtime
            max_runtime = config.get("max_runtime", None)
            if max_runtime:
                assert stats["workflows"][workflow]["overall"] <= max_runtime, (
                    f"Expected max runtime of {max_runtime}, found: {stats['workflows'][workflow]['overall']} for workflow: {workflow}"
                )
            # Check expected artifacts
            for artifact in workflow_artifacts:
                if artifact.endswith(".parquet"):
                    output_df = pd.read_parquet(output_path / artifact)

                    # Check number of rows is within the expected range
                    assert (
                        config["row_range"][0]
                        <= len(output_df)
                        <= config["row_range"][1]
                    ), (
                        f"Expected between {config['row_range'][0]} and {config['row_range'][1]}, found: {len(output_df)} for file: {artifact}"
                    )

                    # Collect rows containing NaN in columns where NaN is not allowed
                    nan_df = output_df.loc[
                        :,
                        ~output_df.columns.isin(config.get("nan_allowed_columns", [])),
                    ]
                    nan_df = nan_df[nan_df.isna().any(axis=1)]
                    assert len(nan_df) == 0, (
                        f"Found {len(nan_df)} rows with NaN values for file: {artifact} on columns: {nan_df.columns[nan_df.isna().any()].tolist()}"
                    )
    def __run_query(self, root: Path, query_config: dict[str, str]):
        command = [
            "poetry",

@@ -221,6 +332,27 @@ class TestIndexer:
        log.info("running command ", " ".join(command))
        return subprocess.run(command, capture_output=True, text=True)

    def __run_multi_index_query(self, root: Path, query_config: dict[str, str]):
        command = [
            "poetry",
            "run",
            "poe",
            "query",
            "--root",
            root.resolve().as_posix(),
            "--method",
            query_config["method"],
            "--community-level",
            str(query_config.get("community_level", 2)),
            "--query",
            query_config["query"],
            "--config",
            root.resolve().as_posix() + "/settings_miq.yml",
        ]
        log.info("running command: %s", " ".join(command))
        return subprocess.run(command, capture_output=True, text=True)

    @cleanup(skip=debug)
    @mock.patch.dict(
        os.environ,
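For reference, the command that __run_multi_index_query assembles flattens to something like the line below (illustrative values: the method, community level, and query text come from query_config, and the root shown here is the text fixture directory). The main addition over __run_query is the explicit --config flag pointing at settings_miq.yml:

    poetry run poe query --root ./tests/fixtures/text --method local --community-level 2 --query "<query text>" --config ./tests/fixtures/text/settings_miq.yml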
@@ -262,10 +394,25 @@ class TestIndexer:
        if dispose is not None:
            dispose()

        dispose2 = None
        if azure is not None:
            dispose2 = asyncio.run(prepare_azurite_data2(input_path, azure))

        print("running indexer")
        self.__run_indexer2(root, input_file_type)
        print("indexer complete")

        if dispose2 is not None:
            dispose2()

        if not workflow_config.get("skip_assert"):
            print("performing dataset assertions")
            self.__assert_indexer_outputs(root, workflow_config)

        if not workflow_config.get("skip_assert"):
            print("performing dataset assertions for output2")
            self.__assert_indexer_outputs2(root, workflow_config)

        print("running queries")
        for query in query_config:
            result = self.__run_query(root, query)

@@ -274,3 +421,12 @@ class TestIndexer:
            assert result.returncode == 0, "Query failed"
            assert result.stdout is not None, "Query returned no output"
            assert len(result.stdout) > 0, "Query returned empty output"

        print("running multi-index queries")
        for query in query_config:
            result = self.__run_multi_index_query(root, query)
            print(f"Query: {query}\nResponse: {result.stdout}")
            assert result.returncode == 0, "Query failed"
            assert result.stdout is not None, "Query returned no output"
            assert len(result.stdout) > 0, "Query returned empty output"