mirror of
https://github.com/microsoft/graphrag.git
synced 2025-06-26 23:19:58 +00:00
Attempt to update smoke tests for multi index query.
This commit is contained in:
parent
b409cef0ce
commit
a63eed0b76
61
tests/fixtures/text/settings_input2.yml
vendored
Normal file
61
tests/fixtures/text/settings_input2.yml
vendored
Normal file
@ -0,0 +1,61 @@
|
||||
models:
|
||||
default_chat_model:
|
||||
azure_auth_type: api_key
|
||||
type: ${GRAPHRAG_LLM_TYPE}
|
||||
api_key: ${GRAPHRAG_API_KEY}
|
||||
api_base: ${GRAPHRAG_API_BASE}
|
||||
api_version: ${GRAPHRAG_API_VERSION}
|
||||
deployment_name: ${GRAPHRAG_LLM_DEPLOYMENT_NAME}
|
||||
model: ${GRAPHRAG_LLM_MODEL}
|
||||
tokens_per_minute: ${GRAPHRAG_LLM_TPM}
|
||||
requests_per_minute: ${GRAPHRAG_LLM_RPM}
|
||||
model_supports_json: true
|
||||
concurrent_requests: 50
|
||||
async_mode: threaded
|
||||
default_embedding_model:
|
||||
azure_auth_type: api_key
|
||||
type: ${GRAPHRAG_EMBEDDING_TYPE}
|
||||
api_key: ${GRAPHRAG_API_KEY}
|
||||
api_base: ${GRAPHRAG_API_BASE}
|
||||
api_version: ${GRAPHRAG_API_VERSION}
|
||||
deployment_name: ${GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME}
|
||||
model: ${GRAPHRAG_EMBEDDING_MODEL}
|
||||
tokens_per_minute: null
|
||||
requests_per_minute: null
|
||||
concurrent_requests: 50
|
||||
async_mode: threaded
|
||||
|
||||
vector_store:
|
||||
default_vector_store2:
|
||||
type: "azure_ai_search"
|
||||
url: ${AZURE_AI_SEARCH_URL_ENDPOINT}
|
||||
api_key: ${AZURE_AI_SEARCH_API_KEY}
|
||||
container_name: "simple_text_ci2"
|
||||
|
||||
input:
|
||||
type: file # or blob
|
||||
file_type: text # [csv, text, json]
|
||||
base_dir: "./tests/fixtures/text/input2"
|
||||
file_encoding: utf-8
|
||||
file_pattern: ".*\\.txt$$"
|
||||
|
||||
output:
|
||||
type: file # [file, blob, cosmosdb]
|
||||
base_dir: "./tests/fixtures/text/output2"
|
||||
|
||||
|
||||
extract_claims:
|
||||
enabled: true
|
||||
|
||||
community_reports:
|
||||
prompt: "prompts/community_report.txt"
|
||||
max_length: 2000
|
||||
max_input_length: 8000
|
||||
|
||||
snapshots:
|
||||
embeddings: True
|
||||
|
||||
drift_search:
|
||||
n_depth: 1
|
||||
drift_k_followups: 3
|
||||
primer_folds: 3
|
62
tests/fixtures/text/settings_miq.yml
vendored
Normal file
62
tests/fixtures/text/settings_miq.yml
vendored
Normal file
@ -0,0 +1,62 @@
|
||||
models:
|
||||
default_chat_model:
|
||||
azure_auth_type: api_key
|
||||
type: ${GRAPHRAG_LLM_TYPE}
|
||||
api_key: ${GRAPHRAG_API_KEY}
|
||||
api_base: ${GRAPHRAG_API_BASE}
|
||||
api_version: ${GRAPHRAG_API_VERSION}
|
||||
deployment_name: ${GRAPHRAG_LLM_DEPLOYMENT_NAME}
|
||||
model: ${GRAPHRAG_LLM_MODEL}
|
||||
tokens_per_minute: ${GRAPHRAG_LLM_TPM}
|
||||
requests_per_minute: ${GRAPHRAG_LLM_RPM}
|
||||
model_supports_json: true
|
||||
concurrent_requests: 50
|
||||
async_mode: threaded
|
||||
default_embedding_model:
|
||||
azure_auth_type: api_key
|
||||
type: ${GRAPHRAG_EMBEDDING_TYPE}
|
||||
api_key: ${GRAPHRAG_API_KEY}
|
||||
api_base: ${GRAPHRAG_API_BASE}
|
||||
api_version: ${GRAPHRAG_API_VERSION}
|
||||
deployment_name: ${GRAPHRAG_EMBEDDING_DEPLOYMENT_NAME}
|
||||
model: ${GRAPHRAG_EMBEDDING_MODEL}
|
||||
tokens_per_minute: null
|
||||
requests_per_minute: null
|
||||
concurrent_requests: 50
|
||||
async_mode: threaded
|
||||
|
||||
outputs:
|
||||
index1:
|
||||
type: file # [file, blob, cosmosdb]
|
||||
base_dir: "./tests/fixtures/text/output"
|
||||
index2:
|
||||
type: file # [file, blob, cosmosdb]
|
||||
base_dir: "./tests/fixtures/text/output2"
|
||||
|
||||
vector_store:
|
||||
index1:
|
||||
type: "azure_ai_search"
|
||||
url: ${AZURE_AI_SEARCH_URL_ENDPOINT}
|
||||
api_key: ${AZURE_AI_SEARCH_API_KEY}
|
||||
container_name: "simple_text_ci"
|
||||
index2:
|
||||
type: "azure_ai_search"
|
||||
url: ${AZURE_AI_SEARCH_URL_ENDPOINT}
|
||||
api_key: ${AZURE_AI_SEARCH_API_KEY}
|
||||
container_name: "simple_text_ci2"
|
||||
|
||||
extract_claims:
|
||||
enabled: true
|
||||
|
||||
community_reports:
|
||||
prompt: "prompts/community_report.txt"
|
||||
max_length: 2000
|
||||
max_input_length: 8000
|
||||
|
||||
snapshots:
|
||||
embeddings: True
|
||||
|
||||
drift_search:
|
||||
n_depth: 1
|
||||
drift_k_followups: 3
|
||||
primer_folds: 3
|
@ -83,6 +83,7 @@ def cleanup(skip: bool = False):
|
||||
root = Path(kwargs["input_path"])
|
||||
shutil.rmtree(root / "output", ignore_errors=True)
|
||||
shutil.rmtree(root / "cache", ignore_errors=True)
|
||||
shutil.rmtree(root / "output2", ignore_errors=True)
|
||||
|
||||
return wrapper
|
||||
|
||||
@ -118,6 +119,35 @@ async def prepare_azurite_data(input_path: str, azure: dict) -> Callable[[], Non
|
||||
|
||||
return lambda: input_storage._delete_container() # noqa: SLF001
|
||||
|
||||
async def prepare_azurite_data2(input_path: str, azure: dict) -> Callable[[], None]:
    """Seed the Azurite blob container with the second input data set.

    The container named by ``azure["input_container"]`` is deleted and
    recreated, then every ``*.txt`` and ``*.csv`` file found directly in
    ``<input_path>/input2`` is uploaded as UTF-8 text. When
    ``azure["input_base_dir"]`` is set, each file is stored under that
    directory; otherwise it is stored under its bare file name.

    Returns:
        A zero-argument callable that deletes the container again.
    """
    container = azure["input_container"]
    base_dir = azure.get("input_base_dir")
    source_dir = Path(input_path) / "input2"

    input_storage = BlobPipelineStorage(
        connection_string=WELL_KNOWN_AZURITE_CONNECTION_STRING,
        container_name=container,
    )
    # Drop and recreate the container so data from earlier runs is gone.
    input_storage._delete_container()  # noqa: SLF001
    input_storage._create_container()  # noqa: SLF001

    # Collect text files first, then CSV files, before uploading anything.
    pending = [*source_dir.glob("*.txt"), *source_dir.glob("*.csv")]
    for source_file in pending:
        if base_dir:
            blob_name = str(Path(base_dir) / source_file.name)
        else:
            blob_name = source_file.name
        contents = source_file.read_bytes().decode("utf-8")
        await input_storage.set(blob_name, contents, encoding="utf-8")

    def dispose():
        return input_storage._delete_container()  # noqa: SLF001

    return dispose
|
||||
|
||||
|
||||
class TestIndexer:
|
||||
params: ClassVar[dict[str, list[tuple[str, dict[str, Any]]]]] = {
|
||||
@ -151,6 +181,35 @@ class TestIndexer:
|
||||
f"Indexer failed with return code: {completion.returncode}"
|
||||
)
|
||||
|
||||
def __run_indexer2(
    self,
    root: Path,
    input_file_type: str,
):
    """Run ``poetry run poe index`` against ``<root>/settings_input2.yml``.

    Args:
        root: Fixture directory passed as ``--root``; it must contain
            ``settings_input2.yml``, which is passed as ``--config``.
        input_file_type: Value exported to the child process as the
            ``GRAPHRAG_INPUT_FILE_TYPE`` environment variable.

    Raises:
        AssertionError: If the indexer process exits with a non-zero code.
    """
    resolved_root = root.resolve()
    command = [
        "poetry",
        "run",
        "poe",
        "index",
        # Only pass --verbose when the module-level debug flag is set.
        *(["--verbose"] if debug else []),
        "--root",
        resolved_root.as_posix(),
        "--logger",
        "print",
        "--method",
        "standard",
        "--config",
        (resolved_root / "settings_input2.yml").as_posix(),
    ]
    # The command is passed as a lazy %-argument. Without the placeholder the
    # logging module fails to format the record and the command is never
    # actually logged.
    log.info("running command %s", " ".join(command))
    completion = subprocess.run(
        command, env={**os.environ, "GRAPHRAG_INPUT_FILE_TYPE": input_file_type}
    )
    assert completion.returncode == 0, (
        f"Indexer failed with return code: {completion.returncode}"
    )
|
||||
|
||||
def __assert_indexer_outputs(
|
||||
self, root: Path, workflow_config: dict[str, dict[str, Any]]
|
||||
):
|
||||
@ -202,6 +261,58 @@ class TestIndexer:
|
||||
f"Found {len(nan_df)} rows with NaN values for file: {artifact} on columns: {nan_df.columns[nan_df.isna().any()].tolist()}"
|
||||
)
|
||||
|
||||
def __assert_indexer_outputs2(
    self, root: Path, workflow_config: dict[str, dict[str, Any]]
):
    """Validate the artifacts found in ``<root>/output2``.

    The following are asserted, in order:

    * the ``output2`` folder exists;
    * the workflow names recorded in ``stats.json`` are exactly the keys of
      ``workflow_config``;
    * for each workflow with a truthy ``max_runtime``, the recorded
      ``overall`` value does not exceed it;
    * for each ``.parquet`` file listed in ``expected_artifacts``, the row
      count lies within ``row_range`` (inclusive) and no row holds a NaN in
      any column not listed in ``nan_allowed_columns``.

    Raises:
        AssertionError: When any of the checks above fails.
    """
    output_path = root / "output2"
    assert output_path.exists(), "output2 folder does not exist"

    # Load the run statistics written alongside the artifacts.
    stats_text = (output_path / "stats.json").read_bytes().decode("utf-8")
    stats = json.loads(stats_text)

    # Every configured workflow must have run, and nothing else.
    expected = set(workflow_config.keys())
    recorded = set(stats["workflows"].keys())
    assert recorded == expected, (
        f"Workflows missing from stats.json: {expected - recorded}. Unexpected workflows in stats.json: {recorded - expected}"
    )

    for workflow, config in workflow_config.items():
        artifacts = config.get("expected_artifacts", [])

        # The runtime ceiling is optional; skip the check when it is unset.
        max_runtime = config.get("max_runtime", None)
        if max_runtime:
            assert stats["workflows"][workflow]["overall"] <= max_runtime, (
                f"Expected max runtime of {max_runtime}, found: {stats['workflows'][workflow]['overall']} for workflow: {workflow}"
            )

        for artifact in artifacts:
            # Only parquet artifacts are inspected.
            if not artifact.endswith(".parquet"):
                continue

            table = pd.read_parquet(output_path / artifact)
            row_count = len(table)

            # Row count must fall inside the inclusive configured range.
            assert (
                config["row_range"][0] <= row_count <= config["row_range"][1]
            ), (
                f"Expected between {config['row_range'][0]} and {config['row_range'][1]}, found: {row_count} for file: {artifact}"
            )

            # Ignore columns that are allowed to hold NaN, then keep only the
            # rows that still contain at least one NaN.
            nan_allowed = config.get("nan_allowed_columns", [])
            checked = table.loc[:, ~table.columns.isin(nan_allowed)]
            offending = checked[checked.isna().any(axis=1)]
            assert len(offending) == 0, (
                f"Found {len(offending)} rows with NaN values for file: {artifact} on columns: {offending.columns[offending.isna().any()].tolist()}"
            )
|
||||
|
||||
|
||||
def __run_query(self, root: Path, query_config: dict[str, str]):
|
||||
command = [
|
||||
"poetry",
|
||||
@ -221,6 +332,27 @@ class TestIndexer:
|
||||
log.info("running command ", " ".join(command))
|
||||
return subprocess.run(command, capture_output=True, text=True)
|
||||
|
||||
def __run_multi_index_query(self, root: Path, query_config: dict[str, str]):
    """Run ``poetry run poe query`` against ``<root>/settings_miq.yml``.

    Args:
        root: Fixture directory passed as ``--root``; it must contain
            ``settings_miq.yml``, which is passed as ``--config``.
        query_config: Must provide ``"method"`` and ``"query"``; the
            optional ``"community_level"`` defaults to ``2``.

    Returns:
        The ``subprocess.CompletedProcess`` for the query, with stdout and
        stderr captured as text.

    Raises:
        KeyError: If ``"method"`` or ``"query"`` is missing from
            ``query_config``.
    """
    resolved_root = root.resolve()
    command = [
        "poetry",
        "run",
        "poe",
        "query",
        "--root",
        resolved_root.as_posix(),
        "--method",
        query_config["method"],
        "--community-level",
        str(query_config.get("community_level", 2)),
        "--query",
        query_config["query"],
        "--config",
        (resolved_root / "settings_miq.yml").as_posix(),
    ]

    # The command is passed as a lazy %-argument. Without the placeholder the
    # logging module fails to format the record and the command is never
    # actually logged.
    log.info("running command %s", " ".join(command))
    return subprocess.run(command, capture_output=True, text=True)
|
||||
|
||||
@cleanup(skip=debug)
|
||||
@mock.patch.dict(
|
||||
os.environ,
|
||||
@ -262,10 +394,25 @@ class TestIndexer:
|
||||
if dispose is not None:
|
||||
dispose()
|
||||
|
||||
dispose2 = None
|
||||
if azure is not None:
|
||||
dispose2 = asyncio.run(prepare_azurite_data2(input_path, azure))
|
||||
|
||||
print("running indexer")
|
||||
self.__run_indexer2(root, input_file_type)
|
||||
print("indexer complete")
|
||||
|
||||
if dispose2 is not None:
|
||||
dispose2()
|
||||
|
||||
if not workflow_config.get("skip_assert"):
|
||||
print("performing dataset assertions")
|
||||
self.__assert_indexer_outputs(root, workflow_config)
|
||||
|
||||
if not workflow_config.get("skip_assert"):
|
||||
print("performing dataset assertions")
|
||||
self.__assert_indexer_outputs2(root, workflow_config)
|
||||
|
||||
print("running queries")
|
||||
for query in query_config:
|
||||
result = self.__run_query(root, query)
|
||||
@ -274,3 +421,12 @@ class TestIndexer:
|
||||
assert result.returncode == 0, "Query failed"
|
||||
assert result.stdout is not None, "Query returned no output"
|
||||
assert len(result.stdout) > 0, "Query returned empty output"
|
||||
|
||||
print("running multi_index_queries")
|
||||
for query in query_config:
|
||||
result = self.__run_multi_index_query(root, query)
|
||||
print(f"Query: {query}\nResponse: {result.stdout}")
|
||||
|
||||
assert result.returncode == 0, "Query failed"
|
||||
assert result.stdout is not None, "Query returned no output"
|
||||
assert len(result.stdout) > 0, "Query returned empty output"
|
||||
|
Loading…
x
Reference in New Issue
Block a user