2024-09-16 12:10:29 -07:00
|
|
|
# Copyright (c) 2024 Microsoft Corporation.
|
|
|
|
# Licensed under the MIT License
|
|
|
|
|
|
|
|
from typing import cast
|
|
|
|
|
|
|
|
import pandas as pd
|
|
|
|
from datashaper import Workflow
|
2024-09-17 17:04:42 -07:00
|
|
|
from pandas.testing import assert_series_equal
|
2024-09-16 12:10:29 -07:00
|
|
|
|
2024-11-15 16:41:10 -08:00
|
|
|
from graphrag.config.create_graphrag_config import create_graphrag_config
|
|
|
|
from graphrag.index.config.workflow import PipelineWorkflowConfig
|
2024-10-24 10:20:03 -07:00
|
|
|
from graphrag.index.context import PipelineRunContext
|
2024-11-15 16:41:10 -08:00
|
|
|
from graphrag.index.create_pipeline_config import create_pipeline_config
|
2024-10-24 10:20:03 -07:00
|
|
|
from graphrag.index.run.utils import create_run_context
|
2024-09-17 10:32:25 -07:00
|
|
|
|
2024-09-30 10:46:07 -07:00
|
|
|
# show all columns when printing dataframes so test-failure diffs are readable
pd.set_option("display.max_columns", None)
|
|
|
|
|
2024-09-16 12:10:29 -07:00
|
|
|
|
|
|
|
def load_input_tables(inputs: list[str]) -> dict[str, pd.DataFrame]:
    """Harvest all the referenced input IDs from the workflow being tested and pass them here.

    Returns a map of table name -> dataframe; Workflow looks inputs up by name.
    """
    # stick all the inputs in a map - Workflow looks them up by name
    input_tables: dict[str, pd.DataFrame] = {}

    # the source documents table is always loaded for every workflow under test
    source = pd.read_parquet("tests/verbs/data/source_documents.parquet")
    input_tables["source"] = source

    # NOTE: renamed loop variable from `input`, which shadowed the builtin
    for input_name in inputs:
        # remove the workflow: prefix if it exists, because that is not part of the actual table filename
        name = input_name.replace("workflow:", "")
        input_tables[input_name] = pd.read_parquet(f"tests/verbs/data/{name}.parquet")

    return input_tables
|
|
|
|
|
|
|
|
|
2024-12-05 09:57:26 -08:00
|
|
|
def load_test_table(output: str) -> pd.DataFrame:
    """Pass in the workflow output (generally the workflow name)"""
    # expected outputs live alongside the input fixtures in the test data folder
    table_path = f"tests/verbs/data/{output}.parquet"
    return pd.read_parquet(table_path)
|
|
|
|
|
|
|
|
|
2024-09-17 10:32:25 -07:00
|
|
|
def get_config_for_workflow(name: str) -> PipelineWorkflowConfig:
    """Instantiates the bare minimum config to get a default workflow config for testing."""
    config = create_graphrag_config()

    # this flag needs to be set before creating the pipeline config, or the entire covariate workflow will be excluded
    config.claim_extraction.enabled = True

    pipeline_config = create_pipeline_config(config)

    # locate the single workflow definition matching the requested name
    matching = (wf for wf in pipeline_config.workflows if wf.name == name)
    workflow = next(matching)

    return cast("PipelineWorkflowConfig", workflow.config)
|
2024-09-17 10:32:25 -07:00
|
|
|
|
|
|
|
|
2024-09-16 12:10:29 -07:00
|
|
|
async def get_workflow_output(
    input_tables: dict[str, pd.DataFrame],
    schema: dict,
    context: PipelineRunContext | None = None,
) -> pd.DataFrame:
    """Pass in the input tables, the schema, and the output name"""
    # the bare minimum workflow is the pipeline schema and table context
    workflow = Workflow(
        schema=schema,
        input_tables=input_tables,
    )

    # fall back to a default run context when the caller does not supply one
    if context is None:
        context = create_run_context(None, None, None)

    await workflow.run(context=context)

    # if there's only one output, it is the default here, no name required
    return cast("pd.DataFrame", workflow.output())
|
2024-09-16 12:10:29 -07:00
|
|
|
|
|
|
|
|
2024-09-17 17:04:42 -07:00
|
|
|
def compare_outputs(
|
|
|
|
actual: pd.DataFrame, expected: pd.DataFrame, columns: list[str] | None = None
|
|
|
|
) -> None:
|
|
|
|
"""Compare the actual and expected dataframes, optionally specifying columns to compare.
|
2024-10-30 11:59:44 -06:00
|
|
|
This uses assert_series_equal since we are sometimes intentionally omitting columns from the actual output.
|
|
|
|
"""
|
2024-09-17 17:04:42 -07:00
|
|
|
cols = expected.columns if columns is None else columns
|
2024-09-23 13:24:06 -07:00
|
|
|
|
2024-12-06 14:08:24 -06:00
|
|
|
assert len(actual) == len(expected), (
|
|
|
|
f"Expected: {len(expected)} rows, Actual: {len(actual)} rows"
|
|
|
|
)
|
2024-09-23 13:24:06 -07:00
|
|
|
|
|
|
|
for column in cols:
|
|
|
|
assert column in actual.columns
|
|
|
|
try:
|
2024-09-17 17:04:42 -07:00
|
|
|
# dtypes can differ since the test data is read from parquet and our workflow runs in memory
|
2024-09-25 17:35:44 -07:00
|
|
|
assert_series_equal(
|
|
|
|
actual[column], expected[column], check_dtype=False, check_index=False
|
|
|
|
)
|
2024-09-23 13:24:06 -07:00
|
|
|
except AssertionError:
|
|
|
|
print("Expected:")
|
|
|
|
print(expected[column])
|
|
|
|
print("Actual:")
|
2024-09-25 17:35:44 -07:00
|
|
|
print(actual[column])
|
2024-09-23 13:24:06 -07:00
|
|
|
raise
|